Tensorflow查看网络(inspect)、冻结变量(freeze)和迁移训练(finetune)
2018-01-05 14:59
3491 查看
Tensorflow查看网络、冻结变量和迁移训练
[align=center](Inspect network structure, freeze graph variables, and finetune/transfer learning in Tensorflow)[/align]1. 查看网络结构和参数
python /usr/local/lib/python2.7/dist-packages/tensorflow/python/tools/inspect_checkpoint.py --file_name=model.ckpt-1562770 --tensor_name=unit_1_2/sub1/conv1/DW
源码中的inspect_checkpoint.py可以看ckpt文件中的层和某层的权重值
如果只有--file_name就只显示层,如果还有--tensor_name就能显示那一层的权重
2. 只训练graph中部分变量(相当于冻结了其他变量)
Tensorflow在构建graph的过程中会默认自动收集一些变量名到对应的Collection。例如TRAINABLE_VARIABLES就是所有可训练的变量集合。因此可以通过使用tf.get_collection,指定TRAINABLE_VARIABLES,使其仅包含我们需要重新训练的变量,来冻结其他变量的训练。
例子如下:
first_train_vars = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, "unit_last") trainable_variables = first_train_vars #print trainable_variables grads = self.optimizer.compute_gradients(self.cost, self.trainable_variables)
3. 更改graph后恢复训练
根据monitored_session.py,使用MonitoredTrainingSession来开启控制Session的时候,若指定的checkpoint路径中有上次的存档,则现有源码只能严格按照之前训练恢复。因此我们需要一个空的checkpoint路径,此时MonitoredTrainingSession就会执行init_op以及init_fn。在init_fn中自己添加恢复函数,并把init_fn作为参数加入MonitoredTrainingSession中的scaffold即可。例子如下:
variables_to_restore = tf.contrib.framework.get_variables_to_restore( exclude=['logit']) init_assign_op, init_feed_dict = tf.contrib.framework.assign_from_checkpoint( ckpt.model_checkpoint_path, variables_to_restore) def InitAssignFn(scaffold, sess): sess.run(init_assign_op, init_feed_dict) scaffold = tf.train.Scaffold(saver=tf.train.Saver(), init_fn=InitAssignFn)
相关文章推荐
- TensorFlow迁移学习-使用谷歌训练好的Inception-v3网络进行分类
- tensorflow 保存训练模型ckpt 查看ckpt文件中的变量名和对应值
- tensorflow从已经训练好的模型中,恢复(指定)权重(构建新变量、网络)并继续训练(finetuning)
- matlab训练好神经网络之后,查看其权值参数。
- tensorflow保存网络参数并调用迁移参数
- tensorflow训练cnn网络实现避障与导航(二)V-rep仿真环境的搭建
- TensorFlow 深度学习框架(7)-- 变量管理及训练模型的保存与加载
- tensorflow的基本用法(七)——使用MNIST训练神经网络
- 【TensorFlow系列】【七】单图片多标签的分类网络搭建与训练
- 查看tensorflow ckpt文件中的变量名和对应值
- 基于MNIST数据集使用TensorFlow训练一个没有隐含层的浅层神经网络
- tensorflow将训练好的模型freeze,即将权重固化到图里面,并使用该模型进行预测
- TensorFlow 深度学习框架(7)-- 变量管理及训练模型的保存与加载
- 利用TensorFlow训练简单的二分类神经网络模型
- 使用TensorFlow训练神经网络识别MNIST数据代码
- tensorflow学习笔记四:mnist实例--用简单的神经网络来训练和测试
- 基于MNIST数据集使用TensorFlow训练一个包含一个隐含层的全连接神经网络
- 重要更新 | 谷歌发布 TensorFlow 1.4,迁移Keras,支持分布式训练
- TensorFlow 深度学习框架(7)-- 变量管理及训练模型的保存与加载
- Tensorflow-查看保存的变量及名字