tensorflow 恢复(restore)模型的两种方式
1. 介绍
首先我们要理解TensorFlow的一个规则,首先构建计算图(graph),然后初始化graph中的data,这两步是分开的。
2. 如何恢复模型
有两种方式(这两种方式有比较大的不同):
2.1 重新使用代码构建图
举个例子(完整代码):
def build_graph(): w1 = tf.Variable([1,3,10,15],name='W1',dtype=tf.float32) w2 = tf.Variable([3,4,2,18],name='W2',dtype=tf.float32) w3 = tf.placeholder(shape=[4],dtype=tf.float32,name='W3') w4 = tf.Variable([100,100,100,100],dtype=tf.float32,name='W4') add = tf.add(w1,w2,name='add') add1 = tf.add(add,w3,name='add1') return w3,add1 with tf.Session() as sess: ckpt_state = tf.train.get_checkpoint_state('./temp/') if ckpt_state: w3,add1=build_graph() saver = tf.train.Saver() saver.restore(sess, ckpt_state.model_checkpoint_path)else: w3,add1=build_graph() init_op = tf.group(tf.global_variables_initializer(),tf.local_variables_initializer()) sess.run(init_op)saver = tf.train.Saver() a = sess.run(add1,feed_dict={ 3ff7 w3:[1,2,3,4] }) print(a) saver.save(sess,'./temp/model')
上面的流程很简单,首先build_graph(),然后如果有ckpt文件就从该文件中读取数据,否则用sess.run(init_op)初始化数据。
那么第一种restore方法就出来了:
build_graph() saver = tf.train.Saver() saver.restore(sess, ckpt_state.model_checkpoint_path)
首先build graph,等于是将图重新建立了一遍,和之前图的一样,然后将ckpt文件里的数据restore到图里的变量里。
当然,在build graph的过程中,你可以在原有的图里加一些变量,但是加的变量一定要初始化,但是要注意到一个问题,如果使用:
init_op = tf.group(tf.global_variables_initializer(),tf.local_variables_initializer()) sess.run(init_op)
这种方式时,如果定义init_op时的graph中已经存在原有图的变量,那么sess.run(init_op)会将加载进来的数据清空。
为了解决这个问题,两种方式:
-
新定义的变量放在init_op之前,在init_op之后restore(注意,加载好变量后才run(init_op)同样会覆盖)
即,init_op得到当前图中的所有变量,sess.run(init_op)对init_op中的变量进行初始化,所以什么时候定义init_op和什么时候运行run(init_op)都很重要 -
只初始化未初始化的变量
def get_uninitialized_variables(sess): global_vars = tf.global_variables() # print([str(i.name) for i in global_vars]) is_not_initialized = sess.run([tf.is_variable_initialized(var) for var in global_vars]) not_initialized_vars = [v for (v, f) in zip(global_vars, is_not_initialized) if not f] print([str(i.name) for i in not_initialized_vars]) return not_initialized_vars sess.run(tf.variables_initializer(get_uninitialized_variables(sess)))
PS:注意saver = tf.train.Saver()要定义在图构建完成之后
即将被restore的变量不用初始化,但是只有在restore之后,这些变量才会被初始化,所以在restore之前运行这些值会报没有初始化的错。
2.2 利用保存的.meta文件恢复图
参考:Tensorflow如何保存、读取model (即利用训练好的模型测试新数据的准确度)
上面的方式适用于断点续训,且自己有构建图的完整代码,如果我要用别人的网络(fine tune),或者在自己原有网络上修改(即修改原有网络的某个部分),那么将网络的图重新构建一遍会很麻烦,那么我们可以直接从.meta文件中加载网络结构。
2.2.1 get_tensor_by_name
完整代码:
def build_graph(): w1 = tf.Variable([1,3,10,15],name='W1',dtype=tf.float32) w2 = tf.Variable([3,4,2,18],name='W2',dtype=tf.float32) w3 = tf.placeholder(shape=[4],dtype=tf.float32,name='W3') w4 = tf.Variable([100,100,100,100],dtype=tf.float32,name='W4') add = tf.add(w1,w2,name='add') add1 = tf.add(add,w3,name='add1') return w3,add1 with tf.Session() as sess: ckpt_state = tf.train.get_checkpoint_state('./temp/') if ckpt_state: saver = tf.train.import_meta_graph('./temp/model.meta') graph = tf.get_default_graph() w3 = graph.get_tensor_by_name('W3:0') add1 = graph.get_tensor_by_name('add1:0') saver.restore(sess, tf.train.latest_checkpoint('./temp/')) print(sess.run(tf.get_collection('w1')[0])) else: w3,add1=build_graph() init_op = tf.group(tf.global_variables_initializer(),tf.local_variables_initializer()) sess.run(init_op)saver = tf.train.Saver() a = sess.run(add1,feed_dict={ w3:[1,2,3,4] }) print(a) saver.save(sess,'./temp/model')
上面使用了import_meta_graph()来加载图,并用restore给变量赋值。
通过get_tensor_by_name来获取保存的图中的op或变量,之后可以对获取的值进行操作,如果之后save的话,也会将import_meta_graph()中图引用的部分保存下来。
2.2.2
def build_graph(): w1 = tf.Variable([1,3,10,15],name='W1',dtype=tf.float32) w2 = tf.Variable([3,4,2,18],name='W2',dtype=tf.float32) w3 = tf.placeholder(shape=[4],dtype=tf.float32,name='W3') w4 = tf.Variable([100,100,100,100],dtype=tf.float32,name='W4') add = tf.add(w1,w2,name='add') add1 = tf.add(add,w3,name='add1') tf.add_to_collection('w1','W1:0') tf.add_to_collection('w3',w3) tf.add_to_collection('add1',add1) return w3,add1 with tf.Session() as sess: ckpt_state = tf.train.get_checkpoint_state('./temp/') if ckpt_state: saver = tf.train.import_meta_graph('./temp/model.meta') w3 = tf.get_collection('w3')[0] add1 = tf.get_collection('add1')[0] # run init_op before restore saver.restore(sess, tf.train.latest_checkpoint('./temp/')) else: w3,add1=build_graph() init_op = tf.group(tf.global_variables_initializer(),tf.local_variables_initializer()) sess.run(init_op)saver = tf.train.Saver() a = sess.run(add1,feed_dict={ w3:[1,2,3,4] }) print(a) saver.save(sess,'./temp/model')
通过import_meta_graph引进图,通过get_collection获得变量,其实和get_tensor_by_name差不多,但是可能会更方便一点。
3. 总结
总的来说,两种方式都是先构造好图,然后通过restore来给图里的变量赋值。
一个常见的问题是,要引入新的变量,对以前的图进行改造,那么如何初始化新的变量且不覆盖原来的数据?
- 可以先啥都不管把所有的图相关的部分构造好后,得到init_op,然后在restore前run(init_op)
- 对未初始化的变量进行初始化
4. 最后
- Tensorflow 存储和恢复模型 (save restore)
- tensorflow模型的保存与恢复(tf.train.Saver()和saver.restore()方法的运用)
- struts2 接受参数的两种方式(属性驱动和模型驱动)
- tensorflow中关于模型存储和恢复(tf.train.Saver())的问题
- tensorflow中模型的保存和恢复
- tensorflow模型的存储与恢复
- TensorFlow 模型保存与恢复总结(微调、微改已有模型)
- 16、TensorFLow 模型参数的保存与恢复
- AspNetCore 文件上传(模型绑定、Ajax) 两种方式 get到了吗?
- TensorFlow模型保存/载入的两种方法
- tensorflow学习笔记--restore使用模型
- Tensorflow模型的保存与恢复
- tensorflow保存模型与恢复数据
- Tensorflow中两种padding方式“SAME”和“VALID”
- 一个快速完整的教程,以保存和恢复Tensorflow模型。(转)
- tensorflow 学习笔记(十一)- 模型的保存与恢复(Saver)
- Tensorflow保存模型,恢复模型,使用训练好的模型进行预测和提取中间输出(特征)【转】
- 在AngularJS中显示模型中的数据有两种方式:
- TensorFlow保存以及恢复模型找到特定张量以及操作
- tensorflow从已经训练好的模型中,恢复(指定)权重(构建新变量、网络)并继续训练(finetuning)