tensorflow模型保存与可视化
2017-10-22 12:14
405 查看
本例以数据的二分类为例,实现了模型的保存、加载、以及tensorboard的可视化。
1、实现功能
对如下数据进行二分类,[[1.,1.2],[2.,2.3],[3.,3.5],[4.,4.1],[1.,0.8],[2.,1.3],[3.,2.5],[4.,3.1]],如图所示。数据以Y=X为分界线,上部分是1类,下部分是0类。
2、具体代码
代码分为两部分,一个是lt_save.py,主要实现了模型的训练与保存,一个是lt_load.py主要实现模型的加载。2.1 lt_save.py
#coding:utf-8 from __future__ import division import tensorflow as tf import numpy as np import os X = np.array([[1.,1.2],[2.,2.3],[3.,3.5],[4.,4.1],[1.,0.8],[2.,1.3],[3.,2.5],[4.,3.1]]) Y = np.array([[0,1],[0,1],[0,1],[0,1],[1,0],[1,0],[1,0],[1,0]]) #-----------------------------前向过程----------------------- x_input = tf.placeholder(tf.float32,shape = (None,2),name = "x_input") y_input = tf.placeholder(tf.int16,shape = (None,2),name = "y_input") with tf.variable_scope("layer_1"): w1 = tf.Variable(tf.random_normal([ 4000 2,3],stddev=0.5)) b1 = tf.Variable(tf.random_normal([3,])) a1 = tf.nn.relu(tf.nn.bias_add(tf.matmul(x_input,w1),b1)) with tf.variable_scope("layer_2"): w2 = tf.Variable(tf.random_normal([3,4],stddev=1)) b2 = tf.Variable(tf.random_normal([4,])) a2 = tf.nn.relu(tf.nn.bias_add(tf.matmul(a1,w2),b2)) with tf.variable_scope("output"): w3 = tf.Variable(tf.random_normal([4,2],stddev=2)) b3 = tf.Variable(tf.random_normal([2,])) y = tf.nn.bias_add(tf.matmul(a2,w3),b3) y_prediction = tf.arg_max(y,1,name = "prediction") with tf.variable_scope("cost"): cost = tf.reduce_mean(tf.nn.sparse_softmax_cross_entropy_with_logits(y,tf.arg_max(y_input,1))) tf.scalar_summary("loss",cost) #写入日志文件 with tf.variable_scope("accuracy"): corrcet_prediction = tf.equal(tf.arg_max(y,1),tf.arg_max(y_input,1)) acc = tf.reduce_mean(tf.cast(corrcet_prediction,tf.float32)) #这里要将corrcet_prediction转换为浮点型 tf.scalar_summary("acc",acc) #写入日志文件 global_step = tf.Variable(0,trainable=False) train_step = tf.train.AdamOptimizer(0.005).minimize(cost, global_step=global_step) #定义最小化cost函数操作 #------------------------------------------------------------ init = tf.initialize_all_variables() #初始化变量操作 merged = tf.merge_all_summaries() #整理所有的日志文件 STEPS = 3001 SAVE_PATH = "./model/" MODEL_NAME = "lt_model.ckpt" SUMMARY_PATH = "./summary" batch_size = 4 data_size = len(X) print "\ndata_size",data_size saver = tf.train.Saver() with tf.Session() as sess: summary_writer = tf.train.SummaryWriter(SUMMARY_PATH,tf.get_default_graph()) sess.run(init) for i in range(STEPS): start = (i*batch_size)%data_size # print "\nstart",start end = min(start+batch_size,data_size) #这里start 与end 必须要满足data_size能被batch_size整除 #这里也可以用yield进行迭代供给数据 # print "\nend",end summary, _ = sess.run([merged,train_step],feed_dict={x_input:X[start:end],y_input:Y[start:end]}) #得到运行时的日志 summary_writer.add_summary(summary,i) #将所有日志写入文件 if i%100 == 0: loss,accuracy,currect_step = sess.run([cost,acc,global_step],feed_dict={x_input:X[start:end],y_input:Y[start:end]}) print "step=",currect_step,"loss=",loss,"acc=",accuracy if i%3000 == 0: saver.save(sess,os.path.join(SAVE_PATH,MODEL_NAME),global_step=currect_step) #保存模型 print "model has been saved" summary_writer.close() #关闭日志文件 print "all done" #查看可视化结果: #tensorboard --logdir=./summary
运行lt_save.py,保存了模型与日志文件。
2.2 查看日志文件
若想查看日志内容,如loss的变化情况,输入:tensorboard --logdir=./summary
然后会出现一个网址,点击,则可以进入tensorboard,查看保存日志,得到 loss, acc 的数据。如下图所示。
2.3 lt_load.py
这个代码实现了加载模型,并且对一个数据进行分类。#coding:utf-8 import tensorflow as tf import numpy as np X = np.array([[4.,4.5]]) #输入数据 graph = tf.Graph() with graph.as_default(): sess = tf.Session() with sess.as_default(): saver = tf.train.import_meta_graph("./model/lt_model.ckpt-3001.meta")#加载图,这里保存了整个图的结构 saver.restore(sess,"./model/lt_model.ckpt-3001")#加载模型,这里保存了每个变量的值 x_input = graph.get_operation_by_name("x_input").outputs[0]#加载placeholder,这里计算节点为"x_input",其本身没有:0 y_prediction = graph.get_operation_by_name("output/prediction").outputs[0]#从节点"output/prediction" 加载张量 y_prediction。注意.outputs[0]是指 #节点的第一个输出,即y_prediction pred = sess.run(y_prediction,feed_dict={x_input:X})#进行预测 print "prediction is ",pred
输出结果为:prediction is [1]
3 不足之处
保存模型时没有设置checkpoints,应该是训练一定步数保存一次,也没有实现自动加载最新的模型,接下来应当实现。相关文章推荐
- tensorflow的一些代码分析(五) tensorflow模型保存和可视化
- tensorflow笔记:模型的保存与训练过程可视化
- tensorflow笔记:模型的保存与训练过程可视化
- tensorflow笔记:模型的保存与训练过程可视化
- tensorflow笔记:模型的保存与训练过程可视化
- tensorflow保存模型,加载模型,tensorboard可视化
- 10 Tensorflow模型保存与读取
- tensorflow 模型的保存与恢复(Saver)
- tensorflow 1.0 学习:模型的保存与恢复(Saver)
- TensorFlow 模型保存与加载
- tensorflow 4——模型的保存、读取
- 转载:tensorflow保存训练后的模型
- tensorflow-模型保存和加载(二)
- TensorFlow保存和加载训练模型
- TensorFlow 深度学习框架(7)-- 变量管理及训练模型的保存与加载
- 6.TensorFlow模型的保存和读取
- tensorflow(三) 模型保存
- TensorFlow保存和载入模型
- TensorFlow_MNIST 保存、恢复模型及参数
- TensorFlow模型保存和提取的方法