您的位置:首页 > 其它

保存和加载pb模型

2018-01-31 22:20 375 查看
将模型保存为 pb 文件（先用 convert_variables_to_constants 把变量冻结为常量，再序列化写入文件）

# Save two TensorFlow 1.x variables into a single frozen GraphDef (.pb) file.
import os

import tensorflow as tf
from tensorflow.python.framework import graph_util

logdir = 'output/'
output_graph_path = os.path.join(logdir, 'expert_graph.pb')

with tf.variable_scope('conv'):
    # Pass initializer *instances*. Passing the class itself relies on
    # version-specific auto-instantiation inside get_variable.
    w = tf.get_variable('w', [2, 2], tf.float32,
                        initializer=tf.random_normal_initializer())
    b = tf.get_variable('b', [2], tf.float32,
                        initializer=tf.random_normal_initializer())

# A context-managed session is closed even if freezing raises.
with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())  # initialize all variables
    # Freeze both variables in a single call so the file holds exactly one
    # GraphDef. Writing two serialized GraphDefs back-to-back only happens to
    # load because protobuf merges concatenated messages.
    constant_graph = graph_util.convert_variables_to_constants(
        sess, sess.graph_def, ['conv/w', 'conv/b'])

# Create the output directory (no-op if it already exists) before writing.
tf.gfile.MakeDirs(logdir)
with tf.gfile.GFile(output_graph_path, mode='wb') as f:
    f.write(constant_graph.SerializeToString())


加载 pb 模型（pb 文件中的变量在保存时已被冻结为常量，因此加载后无需再初始化变量）

# Load the frozen GraphDef (.pb) file and print the stored values of conv/w
# and conv/b.
import os

import tensorflow as tf
from tensorflow.python.framework import graph_util

logdir = 'output/'
output_graph_path = os.path.join(logdir, 'expert_graph.pb')

# Import into a dedicated graph so the loaded nodes cannot collide with
# anything already present in the default graph.
graph = tf.Graph()
with graph.as_default():
    output_graph_def = tf.GraphDef()
    # Read through tf.gfile, the same file API used when the file was written.
    with tf.gfile.GFile(output_graph_path, 'rb') as f:
        output_graph_def.ParseFromString(f.read())
    # name='' keeps the original node names (no "import/" prefix), so the
    # tensors can be looked up as "conv/w:0" and "conv/b:0".
    tf.import_graph_def(output_graph_def, name='')

# Variables were converted to constants when the file was saved, so there is
# nothing to initialize before reading them.
with tf.Session(graph=graph) as sess:
    w = graph.get_tensor_by_name('conv/w:0')
    print('w:', sess.run(w))

    b = graph.get_tensor_by_name('conv/b:0')
    print('b:', sess.run(b))
内容来自用户分享和网络整理,不保证内容的准确性,如有侵权内容,可联系管理员处理 点击这里给我发消息
标签: