tensorflow从0开始(6)——保存加载模型
2016-06-20 16:52
676 查看
目的
学习tensorflow的目的是能够训练的模型,并且利用已经训练好的模型对新数据进行预测。下文就是一个简单的保存模型加载模型的过程。
保存模型
import tensorflow as tf import os import numpy as np from tensorflow.python.platform import gfile flags = tf.app.flags FLAGS = flags.FLAGS flags.DEFINE_string('summaries_dir', '/tmp/save_graph_logs', 'Summaries directory') data = np.arange(10,dtype=np.int32) with tf.Session() as sess: print("# build graph and run") input1= tf.placeholder(tf.int32, [10], name="input") output1= tf.add(input1, tf.constant(100,dtype=tf.int32), name="output") # data depends on the input data saved_result= tf.Variable(data, name="saved_result") do_save=tf.assign(saved_result,output1) tf.initialize_all_variables() os.system("rm -rf /tmp/save_graph_logs") merged = tf.merge_all_summaries() train_writer = tf.train.SummaryWriter(FLAGS.summaries_dir, sess.graph) os.system("rm -rf /tmp/load") tf.train.write_graph(sess.graph_def, "/tmp/load", "test.pb", False) #proto # now set the data: result,_=sess.run([output1,do_save], {input1: data}) # calculate output1 and assign to 'saved_result' saver = tf.train.Saver(tf.all_variables()) saver.save(sess,"checkpoint.data")
模型图示
加载模型
with tf.Session() as persisted_sess: print("load graph") with gfile.FastGFile("/tmp/load/test.pb",'rb') as f: graph_def = tf.GraphDef() graph_def.ParseFromString(f.read()) persisted_sess.graph.as_default() tf.import_graph_def(graph_def, name='') print("map variables") persisted_result = persisted_sess.graph.get_tensor_by_name("saved_result:0") tf.add_to_collection(tf.GraphKeys.VARIABLES,persisted_result) try: saver = tf.train.Saver(tf.all_variables()) # 'Saver' misnomer! Better: Persister! except:pass print("load data") saver.restore(persisted_sess, "checkpoint.data") # now OK print(persisted_result.eval()) print("DONE")
显示结果
相关文章推荐
- TensorFlow 的简单例子
- 用Python从零实现贝叶斯分类器的机器学习的教程
- My Machine Learning
- 机器学习---学习首页 3ff0
- Spark机器学习(一) -- Machine Learning Library (MLlib)
- bp神经网络及matlab实现
- 反向传播(Backpropagation)算法的数学原理
- 关于SVM的那点破事
- 也谈 机器学习到底有没有用 ?
- TensorFlow人工智能引擎入门教程之九 RNN/LSTM循环神经网络长短期记忆网络使用
- TensorFlow人工智能引擎入门教程之十 最强网络 RSNN深度残差网络 平均准确率96-99%
- TensorFlow人工智能入门教程之十一 最强网络DLSTM 双向长短期记忆网络(阿里小AI实现)
- TensorFlow人工智能引擎入门教程之十二 Caffe转换tensorflow并 跨平台调用
- TensorFlow人工智能入门教程之十三 RCNN 区域卷积网络(视频侦测分析人脸侦测区域检测 )
- TensorFlow人工智能入门教程之十四 自动编码机AutoEncoder 网络
- TensorFlow人工智能引擎入门教程所有目录
- Tensorflow 问题
- 人工智能扫盲漫谈篇 & 2018年1月新课资源推荐
- 人工智能唐宇迪老师专题团购~史无前例最低优惠~