TensorFlow模型op的保存和加载(含演示代码)
2017-11-25 20:37
597 查看
上一篇博文《TensorFlow模型参数的保存和加载》介绍了如何保存和加载TensorFlow模型训练参数,保存对象主要是Tensor/Variables。这一节我们介绍如何保存和复用op。
和Tensor一样,保存op需要在训练时为op指定名字,如下所示:
softmax = tf.nn.softmax(tf.matmul(x, W) + b,name="op_softmax")
在识别阶段,调用get_operation_by_name()函数,以op名字作为参数,如下所示:
op_softmax =sess.graph.get_operation_by_name("op_softmax").outputs[0]
这里需要注意的是,在最后面要加上“.outputs[0]”,否则会出现异常。
如果要直接运行sess.run(op_softmax),需要指定feed_dict。以官方mnist训练案例为例,调用格式为sess.run(op_softmax,feed_dict={x: mnist.test.images})。
在加载复用阶段,W和b的值已经保存在checkpoint数据中,故不需要再次声明W和b。但是,需要通过get_tensor_by_name()获取到x的声明,如下所示:
x = sess.graph.get_tensor_by_name("x:0")
上述op保存和加载操作可总结为:
1. 在训练代码中,为op指定名字;
2. 在复用阶段,通过get_operation_by_name().outputs[0]获取op;
3. 通过get_tensor_by_name()获取到feed_dict输入tensor(即x)的声明,然后执行sess.run(op, feed_dict={x:new_data})。
完整演示代码如下:
假设脚本文件名为op-restore.py,则训练时启动命令为:
识别时启动命令为:
当第一次运行该脚本的时候,如果当前目录没有mnist数据集,则会自动下载数据集,如果网络不稳定,那么下载过程会很缓慢。为方便使用,我把mnist数据集上传到百度网盘,链接地址如下:
http://pan.baidu.com/s/1c2k3gkw
下载后将整个MNIST_data文件夹放到脚本同一目录,运行时就不会触发下载了。
和Tensor一样,保存op需要在训练时为op指定名字,如下所示:
softmax = tf.nn.softmax(tf.matmul(x, W) + b,name="op_softmax")
在识别阶段,调用get_operation_by_name()函数,以op名字作为参数,如下所示:
op_softmax =sess.graph.get_operation_by_name("op_softmax").outputs[0]
这里需要注意的是,在最后面要加上“.outputs[0]”,否则会出现异常。
如果要直接运行sess.run(op_softmax),需要指定feed_dict。以官方mnist训练案例为例,调用格式为sess.run(op_softmax,feed_dict={x: mnist.test.images})。
在加载复用阶段,W和b的值已经保存在checkpoint数据中,故不需要再次声明W和b。但是,需要通过get_tensor_by_name()获取到x的声明,如下所示:
x = sess.graph.get_tensor_by_name("x:0")
上述op保存和加载操作可总结为:
1. 在训练代码中,为op指定名字;
2. 在复用阶段,通过get_operation_by_name().outputs[0]获取op;
3. 通过get_tensor_by_name()获取到feed_dict输入tensor(即x)的声明,然后执行sess.run(op, feed_dict={x:new_data})。
完整演示代码如下:
#!/usr/bin/python3.5 # -*- coding: utf-8 -*- import os import sys import tensorflow as tf from tensorflow.examples.tutorials.mnist import input_data SAVER_DIR = "train-saver/" mnist = input_data.read_data_sets("MNIST_data/", one_hot=True) print ('本脚本须输入参数save或restore') print ('如果当前目录下没有MNIST_data数据,可能需要花费几分钟等待mnist数据下载') print ('如果下载缓慢,可以从百度网盘http://pan.baidu.com/s/1c2k3gkw直接下载,放到运行脚本同一目录下即可') if __name__ =='__main__' and sys.argv[1]=='save': x = tf.placeholder(tf.float32, [None, 784], name="x") labels = tf.placeholder(tf.float32, [None, 10], name="labels") W = tf.Variable(tf.zeros([784, 10]), name="var_W") b = tf.Variable(tf.zeros([10]), name="var_b") # 构建网络op softmax = tf.nn.softmax(tf.matmul(x, W) + b, name="op_softmax") cross_entropy = tf.reduce_mean(-tf.reduce_sum(labels * tf.log(softmax), reduction_indices=[1])) train_step = tf.train.GradientDescentOptimizer(0.5).minimize(cross_entropy) sess = tf.InteractiveSession() tf.global_variables_initializer().run() for i in range(1000): if i%10 == 0: print ("正在进行第 %d 次训练迭代......" % (i)) batch_xs, batch_ys = mnist.train.next_batch(100) sess.run(train_step, feed_dict={x: batch_xs, labels: batch_ys}) # 保存训练结果 if not os.path.exists(SAVER_DIR): print ('不存在训练数据保存目录,现在创建保存目录') os.makedirs(SAVER_DIR) # 初始化saver saver = tf.train.Saver() saver_path = saver.save(sess, "%smodel.ckpt"%(SAVER_DIR)) print ('mnist手写体数字训练结果已保存!') if __name__ =='__main__' and sys.argv[1]=='restore': sess = tf.InteractiveSession() # 导入保存训练结果的文件 saver = tf.train.import_meta_graph("%smodel.ckpt.meta"%(SAVER_DIR)) model_file=tf.train.latest_checkpoint(SAVER_DIR) saver.restore(sess, model_file) # 通过指定变量名获取训练结果中的变量值 x = sess.graph.get_tensor_by_name("x:0") # 执行识别 op_softmax = sess.graph.get_operation_by_name("op_softmax").outputs[0] recog = tf.argmax(op_softmax,1) print("mnist手写体数字测试集识别结果为:%s" % (sess.run(recog, feed_dict={x: mnist.test.images})))
假设脚本文件名为op-restore.py,则训练时启动命令为:
python op-restore.py save
识别时启动命令为:
python op-restore.py restore
当第一次运行该脚本的时候,如果当前目录没有mnist数据集,则会自动下载数据集,如果网络不稳定,那么下载过程会很缓慢。为方便使用,我把mnist数据集上传到百度网盘,链接地址如下:
http://pan.baidu.com/s/1c2k3gkw
下载后将整个MNIST_data文件夹放到脚本同一目录,运行时就不会触发下载了。
相关文章推荐
- TensorFlow模型参数的保存和加载(含演示代码)
- Tensorflow加载预训练模型和保存模型
- TensorFlow 深度学习框架(7)-- 变量管理及训练模型的保存与加载
- TensorFlow构建卷积神经网络/模型保存与加载/正则化
- TensorFlow 深度学习框架(7)-- 变量管理及训练模型的保存与加载
- TensorFlow 深度学习框架(7)-- 变量管理及训练模型的保存与加载
- TensorFlow模型保存和加载方法
- tensorflow模型保存与加载
- TensorFlow 训练好模型参数的保存和恢复代码
- tensorflow初探---模型文件保存和加载
- tensorflow中保存模型、加载模型做预测(不需要再定义网络结构)
- tensorflow加载预训练模型的时候报错:ValueError:No OP Named DecodeBmp in difined operations的解决
- TensorFlow 深度学习框架(7)-- 变量管理及训练模型的保存与加载
- Tensorflow模型的保存和加载
- TensorFlow保存和加载训练模型
- python使用tensorflow保存、加载和使用模型的方法
- TensorFlow 深度学习框架(7)-- 变量管理及训练模型的保存与加载
- TensorFlow保存和加载训练模型
- tensorflow-模型保存和加载(一)
- TensorFlow 深度学习框架(7)-- 变量管理及训练模型的保存与加载