Tensorflow:模型保存/模型恢复?
2017-11-15 21:55
751 查看
仅供学习参考,转载地址:https://vimsky.com/article/3614.html
在Tensorflow中训练一个模型之后:如何保存训练得到的模型?
如何恢复(重新加载)这个保存的模型?
最佳解决办法
为保存和恢复模型添加更多细节功能,下面的答案在持续改进中。对Tensorflow版本0.11以及之后的版本:
保存模型:
import tensorflow as tf #Prepare to feed input, i.e. feed_dict and placeholders w1 = tf.placeholder("float", name="w1") w2 = tf.placeholder("float", name="w2") b1= tf.Variable(2.0,name="bias") feed_dict ={w1:4,w2:8} #Define a test operation that we will restore w3 = tf.add(w1,w2) w4 = tf.multiply(w3,b1,name="op_to_restore") sess = tf.Session() sess.run(tf.global_variables_initializer()) #Create a saver object which will save all the variables saver = tf.train.Saver() #Run the operation by feeding input print sess.run(w4,feed_dict) #Prints 24 which is sum of (w1+w2)*b1 #Now, save the graph saver.save(sess, 'my_test_model',global_step=1000)
恢复模型(重新加载模型):
import tensorflow as tf sess=tf.Session() #First let's load meta graph and restore weights saver = tf.train.import_meta_graph('my_test_model-1000.meta') saver.restore(sess,tf.train.latest_checkpoint('./')) # Access saved Variables directly print(sess.run('bias:0')) # This will print 2, which is the value of bias that we saved # Now, let's access and create placeholders variables and # create feed-dict to feed new data graph = tf.get_default_graph() w1 = graph.get_tensor_by_name("w1:0") w2 = graph.get_tensor_by_name("w2:0") feed_dict ={w1:13.0,w2:17.0} #Now, access the op that you want to run. op_to_restore = graph.get_tensor_by_name("op_to_restore:0") print sess.run(op_to_restore,feed_dict) #This will print 60 which is calculated
想了解更多信息可以参考:
http://cv-tricks.com/tensorflow-tutorial/save-restore-tensorflow-models-quick-complete-tutorial/
次佳解决办法
对于TensorFlow版本0.11.0RC1以及之后的版本,可以直接通过调用tf.train.export_meta_graph和
tf.train.import_meta_graph(根据https://www.tensorflow.org/programmers_guide/meta_graph)保存和恢复模型
保存模式:
w1 = tf.Variable(tf.truncated_normal(shape=[10]), name='w1') w2 = tf.Variable(tf.truncated_normal(shape=[20]), name='w2') tf.add_to_collection('vars', w1) tf.add_to_collection('vars', w2) saver = tf.train.Saver() sess = tf.Session() sess.run(tf.global_variables_initializer()) saver.save(sess, 'my-model') # `save` method will call `export_meta_graph` implicitly. # you will get saved graph files:my-model.meta
恢复模式:
sess = tf.Session() new_saver = tf.train.import_meta_graph('my-model.meta') new_saver.restore(sess, tf.train.latest_checkpoint('./')) all_vars = tf.get_collection('vars') for v in all_vars: v_ = sess.run(v) print(v_)
第三种解决办法
对于TensorFlow版本< 0.11.0RC1:保存的检查点包含模型中的
Variable们的值,而不是模型/图形本身,这意味着恢复检查点时对于图形应该一样。
下面是一个线性回归的例子,其中有一个保存变量检查点的训练循环和一个评估部分,它将恢复在之前的运行中保存的变量并计算预测结果。当然,也可以恢复变量并继续进行训练。
x = tf.placeholder(tf.float32) y = tf.placeholder(tf.float32) w = tf.Variable(tf.zeros([1, 1], dtype=tf.float32)) b = tf.Variable(tf.ones([1, 1], dtype=tf.float32)) y_hat = tf.add(b, tf.matmul(x, w)) ...more setup for optimization and what not... saver = tf.train.Saver() # defaults to saving all variables - in this case w and b with tf.Session() as sess: sess.run(tf.initialize_all_variables()) if FLAGS.train: for i in xrange(FLAGS.training_steps): ...training loop... if (i + 1) % FLAGS.checkpoint_steps == 0: saver.save(sess, FLAGS.checkpoint_dir + 'model.ckpt', global_step=i+1) else: # Here's where you're restoring the variables w and b. # Note that the graph is exactly as it was when the variables were # saved in a prior training run. ckpt = tf.train.get_checkpoint_state(FLAGS.checkpoint_dir) if ckpt and ckpt.model_checkpoint_path: saver.restore(sess, ckpt.model_checkpoint_path) else: ...no checkpoint found... # Now you can run the model to get predictions batch_x = ...load some data... predictions = sess.run(y_hat, feed_dict={x: batch_x})
以下是关于
Variable的文档docs,其中包括保存和恢复。这里是关于
Saver的文档docs。
第四种办法
模型有两部分,第一部分:模型定义,由Supervisor作为模型目录中的
graph.pbtxt保存;第二部分:张量的数值,保存到
model.ckpt-1003418等检查点文件中。
可以使用
tf.import_graph_def恢复模型定义,并使用
Saver恢复权重。
然而,
Saver使用绑定到模型Graph的特殊集合保存变量列表,并且该集合不是使用import_graph_def初始化的,所以不能一起使用这两个(未来会修复这个问题)。目前,还必须手动构建具有相同节点名称的图,并使用
Saver将权重加载到其中。
(或者,您可以使用
import_graph_def,手动创建变量,并为每个变量使用
tf.add_to_collection(tf.GraphKeys.VARIABLES, variable),然后使用
Saver)
第五种办法
也可以采取更简单的方法:步骤1 - 初始化所有变量
W1 = tf.Variable(tf.truncated_normal([6, 6, 1, K], stddev=0.1), name="W1") B1 = tf.Variable(tf.constant(0.1, tf.float32, [K]), name="B1") Similarly, W2, B2, W3, .....
步骤2 - 将模型
Saver中的列表
model_saver = tf.train.Saver() # Train the model and save it in the end model_saver.save(session, "saved_models/CNN_New.ckpt")
步骤3 - 恢复模型(重新加载模型)
with tf.Session(graph=graph_cnn) as session: model_saver.restore(session, "saved_models/CNN_New.ckpt") print("Model restored.") print('Initialized')
步骤4 - 检查变量
W1 = session.run(W1) print(W1)
当在不同的python实例中运行时,使用
with tf.Session() as sess: # Restore latest checkpoint saver.restore(sess, tf.train.latest_checkpoint('saved_model/.')) # Initalize the variables sess.run(tf.global_variables_initializer()) # Get default graph (supply your custom graph if you have one) graph = tf.get_default_graph() # It will give tensor object W1 = graph.get_tensor_by_name('W1:0') # To get the value (numpy array) W1_value = session.run(W1)
第六种办法
可以通过导入Graph,手动创建变量,然后使用保护程序,从graph_def和检查点中进行恢复。实现的代码如下:
链接:https://gist.github.com/nikitakit/6ef3b72be67b86cb7868
(这当然是一种hack的方式,并不能保证这样保存的模型在以后版本的TensorFlow中保持可读。)
第七种办法
如果是一个内部保存的模型,那么只需为所有变量指定恢复器即可restorer = tf.train.Saver(tf.all_variables())
并使用它来恢复当前会话中的变量:
restorer.restore(self._sess, model_file)
对于外部模型,需要指定从外部变量名称到本地变量名称的映射。可以使用该命令查看模型变量名称
python /path/to/tensorflow/tensorflow/python/tools/inspect_checkpoint.py --file_name=/path/to/pretrained_model/model.ckpt
inspect_checkpoint.py脚本可以在Tensorflow源的'./tensorflow/python/tools'文件夹中找到。
要指定映射,可以使用Tensorflow-Worklab,它包含一组类和脚本来训练和重新训练不同的模型。还包括一个重新训练ResNet模型的例子,位于这里
第八种办法
在大多数情况下,使用tf.train.Saver从磁盘保存和恢复是最好的选择:
... # build your model saver = tf.train.Saver() with tf.Session() as sess: ... # train the model saver.save(sess, "/tmp/my_great_model") with tf.Session() as sess: saver.restore(sess, "/tmp/my_great_model") ... # use the model
还可以保存/恢复graph结构(有关详细信息,请参阅MetaGraph documentation)。默认情况下,
Saver将graph结构保存到
.meta文件中。可以调用
import_meta_graph()来恢复它。恢复graph结构并返回一个可用于恢复模型状态的
Saver:
saver = tf.train.import_meta_graph("/tmp/my_great_model.meta") with tf.Session() as sess: saver.restore(sess, "/tmp/my_great_model") ... # use the model
但是,有些情况需要更快的速度。例如,如果要实现早期停止,则希望在训练期间(如验证集中测量)每次改进模型时保存检查点,那么如果某段时间内没有进展,则要回滚到最佳模型。如果将模型保存到磁盘上,并且每次都有提升的情况下,这将大大减慢训练速度。诀窍是将变量状态保存到内存中,然后稍后恢复它们:
... # build your model # get a handle on the graph nodes we need to save/restore the model graph = tf.get_default_graph() gvars = graph.get_collection(tf.GraphKeys.GLOBAL_VARIABLES) assign_ops = [graph.get_operation_by_name(v.op.name + "/Assign") for v in gvars] init_values = [assign_op.inputs[1] for assign_op in assign_ops] with tf.Session() as sess: ... # train the model # when needed, save the model state to memory gvars_state = sess.run(gvars) # when needed, restore the model state feed_dict = {init_value: val for init_value, val in zip(init_values, gvars_state)} sess.run(assign_ops, feed_dict=feed_dict)
快速说明:创建变量
X时,TensorFlow会自动创建一个赋值操作
X/Assign来设置变量的初始值。我们只需使用这些现有的赋值操作,而不是创建占位符和额外的赋值操作(这只会使graph变乱)。每个赋值op的第一个输入是对应该初始化的变量的引用,第二个输入(
assign_op.inputs[1])是初始值。所以为了设置我们想要的任何值(而不是初始值),需要使用
feed_dict并替换初始值。TensorFlow可以为任何操作提供一个值,而不仅仅是占位符。
第九种办法
这是两个基本情况的简单解决方案,不同之处在于是否要从文件加载graph或在运行时构建graph。这个答案适用于Tensorflow 0.12+(包括1.0)。
在代码中重建graph
保存
graph = ... # build the graph saver = tf.train.Saver() # create the saver after the graph with ... as sess: # your session object saver.save(sess, 'my-model')
加载
graph = ... # build the graph saver = tf.train.Saver() # create the saver after the graph with ... as sess: # your session object saver.restore(sess, tf.train.latest_checkpoint('./')) # now you can use the graph, continue training or whatever
从文件加载graph
使用此技术时,请确保所有层/变量都已明确设置唯一的名称,否则Tensorflow将自己创建独一无二的名称,这回导致与存储在文件中的名称不同。这在以前的技术中不是问题,因为在加载和保存时,名称都是"mangled"。
保存
graph = ... # build the graph for op in [ ... ]: # operators you want to use after restoring the model tf.add_to_collection('ops_to_restore', op) saver = tf.train.Saver() # create the saver after the graph with ... as sess: # your session object saver.save(sess, 'my-model')
加载
with ... as sess: # your session object saver = tf.train.import_meta_graph('my-model.meta') saver.restore(sess, tf.train.latest_checkpoint('./')) ops = tf.get_collection('ops_to_restore') # here are your operators in the same order in which you saved them to the collection
相关文章推荐
- tensorflow1.0学习之模型的保存与恢复(Saver)
- TensorFlow_MNIST 保存、恢复模型及参数
- tensorflow 1.0 学习:模型的保存与恢复(Saver)
- TensorFlow 模型的保存与恢复
- [tensorflow] tensorflow 1.0 学习:模型的保存与恢复(Saver)
- Tensorflow学习(6)模型的保存与恢复(saver)
- TensorFlow 模型保存与恢复
- TensorFlow模型保存和恢复简单的例子
- TensorFlow保存和恢复模型的方法总结
- Tensorflow学习笔记:CNN篇(7)——Finetuning,模型的保存与恢复
- TensorFlow 训练好模型参数的保存和恢复代码
- Tensorflow保存模型,恢复模型,使用训练好的模型进行预测和提取中间输出(特征)
- 170615 windows 下 tensorflow1.2.0rc2 模型的保存与恢复
- Tensorflow学习笔记:模型训练数据的保存和恢复的简单实例
- tensorflow学习笔记--模型保存和恢复
- 一份快速完整的Tensorflow模型保存和恢复教程(译)
- tensorflow中模型的保存和恢复
- TensorFlow学习(十二):模型的保存与恢复
- tensorflow保存模型和恢复模型的方法
- tensorflow 模型的保存与恢复(Saver)