tensorflow:从.ckpt文件中读取任意变量
2018-07-06 10:52
543 查看
版权声明:本文为博主原创文章,未经博主允许不得转载。 https://blog.csdn.net/weixin_39999955/article/details/80937112
思路有些混乱,希望大家能理解我的意思。
看了faster rcnn的tensorflow代码,关于fix_variables的作用我不是很明白,所以写了以下代码,读取了预训练模型vgg16得fc6和fc7的参数,以及faster rcnn中heat_to_tail中的fc6和fc7,将它们做了对比,发现结果不一样,说明vgg16的fc6和fc7只是初始化了faster rcnn中heat_to_tail中的fc6和fc7,之后后者被训练。
具体读取任意变量的代码如下:
import tensorflow as tf import numpy as np from tensorflow.python import pywrap_tensorflow file_name = '/home/dl/projectBo/tf-faster-rcnn/data/imagenet_weights/vgg16.ckpt' #.ckpt的路径 name_variable_to_restore = 'vgg_16/fc7/weights' #要读取权重的变量名 reader = pywrap_tensorflow.NewCheckpointReader(file_name) var_to_shape_map = reader.get_variable_to_shape_map() print('shape', var_to_shape_map[name_variable_to_restore]) #输出这个变量的尺寸 fc7_conv = tf.get_variable("fc7", var_to_shape_map[name_variable_to_restore], trainable=False) # 定义接收权重的变量名 restorer_fc = tf.train.Saver({name_variable_to_restore: fc7_conv }) #定义恢复变量的对象 sess = tf.Session() sess.run(tf.variables_initializer([fc7_conv], name='init')) #必须初始化 restorer_fc.restore(sess, file_name) #恢复变量 print(sess.run(fc7_conv)) #输出结果用以上的代码分别读取两个网络的fc6 和 fc7 ,对应参数尺寸和权值都不同,但参数量相同。
再看lib/nets/vgg16.py中的:
(注意注释)def fix_variables(self, sess, pretrained_model): print('Fix VGG16 layers..') with tf.variable_scope('Fix_VGG16') as scope: with tf.device("/cpu:0"): # fix the vgg16 issue from conv weights to fc weights # fix RGB to BGR fc6_conv = tf.get_variable("fc6_conv", [7, 7, 512, 4096], trainable=False) fc7_conv = tf.get_variable("fc7_conv", [1, 1, 4096, 4096], trainable=False) conv1_rgb = tf.get_variable("conv1_rgb", [3, 3, 3, 64], trainable=False) #定义接收权重的变量,不可被训练 restorer_fc = tf.train.Saver({self._scope + "/fc6/weights": fc6_conv, self._scope + "/fc7/weights": fc7_conv, self._scope + "/conv1/conv1_1/weights": conv1_rgb}) #定义恢复变量的对象 restorer_fc.restore(sess, pretrained_model) #恢复这些变量 sess.run(tf.assign(self._variables_to_fix[self._scope + '/fc6/weights:0'], tf.reshape(fc6_conv, self._variables_to_fix[self._scope + '/fc6/weights:0'].get_shape()))) sess.run(tf.assign(self._variables_to_fix[self._scope + '/fc7/weights:0'], tf.reshape(fc7_conv, self._variables_to_fix[self._scope + '/fc7/weights:0'].get_shape()))) sess.run(tf.assign(self._variables_to_fix[self._scope + '/conv1/conv1_1/weights:0'], tf.reverse(conv1_rgb, [2]))) #将vgg16中的fc6、fc7中的权重reshape赋给faster-rcnn中的fc6、fc7我的理解:faster rcnn的网络继承了分类网络的特征提取权重和分类器的权重,让网络从一个比较好的起点开始被训练,有利于训练结果的快速收敛。 阅读更多
相关文章推荐
- 查看tensorflow ckpt文件中的变量名和对应值
- 查看tensorflow ckpt文件中的变量名和对应值
- 从TensorFlow的.ckpt文件中读取网络的参数
- tensorflow 保存训练模型ckpt 查看ckpt文件中的变量名和对应值
- Tensorflow: 读取ckpt文件中的tensor
- 【TensorFlow系列】【二】如何从ckpt文件中拷贝权值到新的变量中
- TensorFlow模型文件保存和读取
- 一次性将整个文件读取到string变量中
- PhpMyadmin任意文件读取漏洞
- java读取项目根路径下和任意磁盘位置下的properties文件
- VS2012中,C# 配置文件读取 + C#多个工程共享共有变量 + 整理using语句
- git 漏洞导致任意文件读取
- jmeter之读取环境变量中的配置文件
- spring-framework 配置文件读取系统变量
- Hadoop Writable深度复制及读取任意<key,value>序列文件
- FFmpeg任意文件读取漏洞分析
- 任意文件读取
- TensorFlow的文件保存与读取——variables_to_restore函数
- 【漏洞预警】FFmpeg曝任意文件读取漏洞
- 从文件中读取结构体变量的数据读取和写入结构体数据到文件