您的位置:首页 > 理论基础 > 计算机网络

Tensorflow查看网络(inspect)、冻结变量(freeze)和迁移训练(finetune)

2018-01-05 14:59 3491 查看

Tensorflow查看网络、冻结变量和迁移训练

[align=center](Inspect network structure, freeze graph variables, and finetune/transfer learning in Tensorflow)[/align]
 

1.    查看网络结构和参数

python
/usr/local/lib/python2.7/dist-packages/tensorflow/python/tools/inspect_checkpoint.py
--file_name=model.ckpt-1562770
--tensor_name=unit_1_2/sub1/conv1/DW


源码中的inspect_checkpoint.py可以看ckpt文件中的层和某层的权重值

如果只有--file_name就只显示层,如果还有--tensor_name就能显示那一层的权重

 

2.    只训练graph中部分变量(相当于冻结了其他变量)

Tensorflow在构建graph的过程中会默认自动收集一些变量名到对应的Collection。例如TRAINABLE_VARIABLES就是所有可训练的变量集合。

因此可以通过使用tf.get_collection,指定TRAINABLE_VARIABLES,使其仅包含我们需要重新训练的变量,来冻结其他变量的训练。

例子如下:

first_train_vars = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES,
"unit_last")
trainable_variables = first_train_vars
#print trainable_variables
grads = self.optimizer.compute_gradients(self.cost, self.trainable_variables)


 

3.    更改graph后恢复训练

根据monitored_session.py,使用MonitoredTrainingSession来开启控制Session的时候,若指定的checkpoint路径中有上次的存档,则现有源码只能严格按照之前训练恢复。因此我们需要一个空的checkpoint路径,此时MonitoredTrainingSession就会执行init_op以及init_fn。在init_fn中自己添加恢复函数,并把init_fn作为参数加入MonitoredTrainingSession中的scaffold即可。

例子如下:

variables_to_restore = tf.contrib.framework.get_variables_to_restore(
exclude=['logit'])
init_assign_op, init_feed_dict = tf.contrib.framework.assign_from_checkpoint(
ckpt.model_checkpoint_path, variables_to_restore)
def InitAssignFn(scaffold, sess):
sess.run(init_assign_op, init_feed_dict)
scaffold = tf.train.Scaffold(saver=tf.train.Saver(), init_fn=InitAssignFn)


 
内容来自用户分享和网络整理,不保证内容的准确性,如有侵权内容,可联系管理员处理 点击这里给我发消息
相关文章推荐