tensorflow的基本用法(十)——保存神经网络参数和加载神经网络参数
2017-04-20 19:57
393 查看
文章作者:Tyan
博客:noahsnail.com | CSDN | 简书
本文主要是使用tensorfl保存神经网络参数和加载神经网络参数。
执行结果如下:
博客:noahsnail.com | CSDN | 简书
本文主要是使用tensorfl保存神经网络参数和加载神经网络参数。
#!/usr/bin/env python # _*_ coding: utf-8 _*_ import tensorflow as tf import numpy as np # 保存神经网络参数 def save_para(): # 定义权重参数 W = tf.Variable([[1, 2, 3], [4, 5, 6]], dtype = tf.float32, name = 'weights') # 定义偏置参数 b = tf.Variable([[1, 2, 3]], dtype = tf.float32, name = 'biases') # 参数初始化 init = tf.global_variables_initializer() # 定义保存参数的saver saver = tf.train.Saver() with tf.Session() as sess: sess.run(init) # 保存session中的数据 save_path = saver.save(sess, 'my_net/save_net.ckpt') # 输出保存路径 print 'Save to path: ', save_path # 恢复神经网络参数 def restore_para(): # 定义权重参数 W = tf.Variable(np.arange(6).reshape((2, 3)), dtype = tf.float32, name = 'weights') # 定义偏置参数 b = tf.Variable(np.arange(3).reshape((1, 3)), dtype = tf.float32, name = 'biases') # 定义提取参数的saver saver = tf.train.Saver() with tf.Session() as sess: # 加载文件中的参数数据,会根据name加载数据并保存到变量W和b中 save_path = saver.restore(sess, 'my_net/save_net.ckpt') # 输出保存路径 print 'Weights: ', sess.run(W) print 'biases: ', sess.run(b) # save_para() restore_para()
执行结果如下:
# save Save to path: my_net/save_net.ckpt # restore Weights: [[ 1. 2. 3.] [ 4. 5. 6.]] biases: [[ 1. 2. 3.]]
参考资料
https://www.youtube.com/user/MorvanZhou相关文章推荐
- tensorflow的基本用法(六)——神经网络可视化
- tensorflow的基本用法(七)——使用MNIST训练神经网络
- tensorflow的基本用法(五)——创建神经网络并训练
- scikit-learn的基本用法——模型保存与加载
- Tensorflow基本语法和实现神经网络
- tensorflow保存网络参数并调用迁移参数
- TensorFlow学习(七):基本神经网络"组件"
- TF Saver 保存/加载训练好模型(网络+参数)的那些事儿
- tensorflow保存网络参数 使用训练好的网络参数进行数据的预测
- 学习TensorFlow,保存学习到的网络结构参数并调用
- 学习TensorFlow,保存学习到的网络结构参数并调用
- 如何保存训练好的神经网络直接进行测试-TensorFlow模型持久化
- scikit-learn的基本用法(八)——模型保存与加载
- tensorflow实现最基本的神经网络 + 对比GD、SGD、batch-GD的训练方法
- 【深度学习】tensorflow加载VGG16的网络结构和模型参数
- 神经网络 tensorflow教程 2.2 下载MNIST 数据集(保存所有图片)
- C++从零实现深度神经网络之五——模型的保存和加载以及画出实时输出曲线
- tensorflow: 保存和加载模型, 参数;以及使用预训练参数方法
- TensorFlow中对训练后的神经网络参数(权重、偏置)提取
- tensorflow 神经网络基本使用