您的位置:首页 > 理论基础 > 计算机网络

tensorflow 学习笔记10 网络模型的保存与提取

2017-08-19 18:35 501 查看
参数的保存与提取关键点就是前后参数的shape,name,dtype都必须一致:

参数的保存:

import tensorflow as tf
w = tf.Variable(tf.constant(1.0, shape=[1]), name="w")
b = tf.Variable(tf.constant(2.0, shape=[1]), name="b")
saver = tf.train.Saver()
with tf.Session() as sess:
sess.run(tf.global_variables_initializer())
saver.save(sess, "Model/model.ckpt")

参数的提取:
import tensorflow as tf
w = tf.Variable(tf.constant(0.0, shape=[1]), name="w")
b = tf.Variable(tf.constant(0.0, shape=[1]), name="b")
saver = tf.train.Saver()
with tf.Session() as sess:
saver.restore(sess, "./Model/model.ckpt")
print("w,b:",sess.run(w),sess.run(b))结果:

内容来自用户分享和网络整理,不保证内容的准确性,如有侵权内容,可联系管理员处理 点击这里给我发消息
标签: