您的位置:首页 > 其它

TensorFlow的文件保存与读取——variables_to_restore函数

2018-03-21 12:08 309 查看
转,原创详见: http://blog.csdn.net/sinat_29957455/article/details/78508793
variables_to_restore函数,是TensorFlow为滑动平均值提供。之前,也介绍过通过使用滑动平均值可以让神经网络模型更加的健壮。我们也知道,其实在TensorFlow中,变量的滑动平均值都是由影子变量所维护的,如果你想要获取变量的滑动平均值需要获取的是影子变量而不是变量本身。
1、滑动平均值模型文件的保存

[python] view plain copyimport tensorflow as tf  
  
if __name__ == "__main__":  
    v = tf.Variable(0.,name="v")  
    #设置滑动平均模型的系数  
    ema = tf.train.ExponentialMovingAverage(0.99)  
    #设置变量v使用滑动平均模型,tf.all_variables()设置所有变量  
    op = ema.apply([v])  
    #获取变量v的名字  
    print(v.name)  
    #v:0  
    #创建一个保存模型的对象  
    save = tf.train.Saver()  
    sess = tf.Session()  
    #初始化所有变量  
    init = tf.initialize_all_variables()  
    sess.run(init)  
    #给变量v重新赋值  
    sess.run(tf.assign(v,10))  
    #应用平均滑动设置  
    sess.run(op)  
    #保存模型文件  
    save.save(sess,"./model.ckpt")  
    #输出变量v之前的值和使用滑动平均模型之后的值  
    print(sess.run([v,ema.average(v)]))  
    #[10.0, 0.099999905]  
上面的代码,是如何来保存一个滑动平均值的模型文件,之前有介绍过滑动平均值和模型文件的保存,所以这里就不再重复了。
2、滑动平均值模型文件的读取

[python] view plain copyv = tf.Variable(1.,name="v")  
#定义模型对象  
saver = tf.train.Saver({"v/ExponentialMovingAverage":v})  
sess = tf.Session()  
saver.restore(sess,"./model.ckpt")  
print(sess.run(v))  
#0.0999999  
对于模型文件的读取,在上一篇博客中有介绍过,这里特别需要注意的一个地方就是,在使用tf.train.Saver函数中,所传递的模型参数是{"v/ExponentialMovingAverage":v}而不是{"v":v},如果你使用的是后面的参数,那么你得到的结果将是10而不是0.09,那是因为后者获取的是变量本身而不是影子变量。是不是感觉使用这种方式来读取模型文件的时候,还需要输入一大串的变量名称。
3、variables_to_restore函数的使用

[pyth
a7f1
on]
 view plain copyv = tf.Variable(1.,name="v")  
#滑动模型的参数的大小并不会影响v的值  
ema = tf.train.ExponentialMovingAverage(0.99)  
print(ema.variables_to_restore())  
#{'v/ExponentialMovingAverage': <tf.Variable 'v:0' shape=() dtype=float32_ref>}  
sess = tf.Session()  
saver = tf.train.Saver(ema.variables_to_restore())  
saver.restore(sess,"./model.ckpt")  
print(sess.run(v))  
#0.0999999  
通过使用variables_to_restore函数,可以使在加载模型的时候将影子变量直接映射到变量的本身,所以我们在获取变量的滑动平均值的时候只需要获取到变量的本身值而不需要去获取影子变量。
内容来自用户分享和网络整理,不保证内容的准确性,如有侵权内容,可联系管理员处理 点击这里给我发消息
标签: