您的位置:首页 > 其它

tensorflow模型保存与可视化

2017-10-22 12:14 405 查看

本例以数据的二分类为例,实现了模型的保存、加载、以及tensorboard的可视化。

1、实现功能

对如下数据进行二分类,[[1.,1.2],[2.,2.3],[3.,3.5],[4.,4.1],[1.,0.8],[2.,1.3],[3.,2.5],[4.,3.1]],如图所示。



数据以Y=X为分界线,上部分是1类,下部分是0类。

2、具体代码

代码分为两部分,一个是lt_save.py,主要实现了模型的训练与保存,一个是lt_load.py主要实现模型的加载。

2.1 lt_save.py

#coding:utf-8
from __future__ import division

import tensorflow as tf
import numpy as np
import os

X = np.array([[1.,1.2],[2.,2.3],[3.,3.5],[4.,4.1],[1.,0.8],[2.,1.3],[3.,2.5],[4.,3.1]])
Y = np.array([[0,1],[0,1],[0,1],[0,1],[1,0],[1,0],[1,0],[1,0]])

#-----------------------------前向过程-----------------------
x_input = tf.placeholder(tf.float32,shape = (None,2),name = "x_input")
y_input = tf.placeholder(tf.int16,shape = (None,2),name = "y_input")

with tf.variable_scope("layer_1"):
w1 = tf.Variable(tf.random_normal([
4000
2,3],stddev=0.5))
b1 = tf.Variable(tf.random_normal([3,]))
a1 = tf.nn.relu(tf.nn.bias_add(tf.matmul(x_input,w1),b1))

with tf.variable_scope("layer_2"):
w2 = tf.Variable(tf.random_normal([3,4],stddev=1))
b2 = tf.Variable(tf.random_normal([4,]))
a2 = tf.nn.relu(tf.nn.bias_add(tf.matmul(a1,w2),b2))

with tf.variable_scope("output"):
w3 = tf.Variable(tf.random_normal([4,2],stddev=2))
b3 = tf.Variable(tf.random_normal([2,]))
y =  tf.nn.bias_add(tf.matmul(a2,w3),b3)
y_prediction = tf.arg_max(y,1,name = "prediction")

with tf.variable_scope("cost"):
cost = tf.reduce_mean(tf.nn.sparse_softmax_cross_entropy_with_logits(y,tf.arg_max(y_input,1)))
tf.scalar_summary("loss",cost)  #写入日志文件

with tf.variable_scope("accuracy"):
corrcet_prediction = tf.equal(tf.arg_max(y,1),tf.arg_max(y_input,1))
acc = tf.reduce_mean(tf.cast(corrcet_prediction,tf.float32))    #这里要将corrcet_prediction转换为浮点型
tf.scalar_summary("acc",acc)   #写入日志文件

global_step = tf.Variable(0,trainable=False)

train_step = tf.train.AdamOptimizer(0.005).minimize(cost, global_step=global_step) #定义最小化cost函数操作
#------------------------------------------------------------

init = tf.initialize_all_variables()  #初始化变量操作
merged = tf.merge_all_summaries()  #整理所有的日志文件

STEPS = 3001
SAVE_PATH = "./model/"
MODEL_NAME = "lt_model.ckpt"
SUMMARY_PATH = "./summary"
batch_size = 4
data_size = len(X)
print "\ndata_size",data_size
saver = tf.train.Saver()
with tf.Session() as sess:

summary_writer = tf.train.SummaryWriter(SUMMARY_PATH,tf.get_default_graph())

sess.run(init)

for i in range(STEPS):
start = (i*batch_size)%data_size
# print "\nstart",start
end = min(start+batch_size,data_size)    #这里start 与end 必须要满足data_size能被batch_size整除
#这里也可以用yield进行迭代供给数据
# print "\nend",end

summary, _ = sess.run([merged,train_step],feed_dict={x_input:X[start:end],y_input:Y[start:end]})  #得到运行时的日志
summary_writer.add_summary(summary,i)  #将所有日志写入文件
if i%100 == 0:
loss,accuracy,currect_step =  sess.run([cost,acc,global_step],feed_dict={x_input:X[start:end],y_input:Y[start:end]})
print "step=",currect_step,"loss=",loss,"acc=",accuracy
if i%3000 == 0:
saver.save(sess,os.path.join(SAVE_PATH,MODEL_NAME),global_step=currect_step)  #保存模型
print "model has been saved"

summary_writer.close()   #关闭日志文件
print "all done"
#查看可视化结果:
#tensorboard --logdir=./summary


运行lt_save.py,保存了模型与日志文件。

2.2 查看日志文件

若想查看日志内容,如loss的变化情况,输入:

tensorboard --logdir=./summary


然后会出现一个网址,点击,则可以进入tensorboard,查看保存日志,得到 loss, acc 的数据。如下图所示。







2.3 lt_load.py

这个代码实现了加载模型,并且对一个数据进行分类。

#coding:utf-8
import tensorflow as tf
import numpy as np

X = np.array([[4.,4.5]])  #输入数据

graph = tf.Graph()
with graph.as_default():
sess = tf.Session()
with sess.as_default():
saver = tf.train.import_meta_graph("./model/lt_model.ckpt-3001.meta")#加载图,这里保存了整个图的结构
saver.restore(sess,"./model/lt_model.ckpt-3001")#加载模型,这里保存了每个变量的值

x_input = graph.get_operation_by_name("x_input").outputs[0]#加载placeholder,这里计算节点为"x_input",其本身没有:0

y_prediction = graph.get_operation_by_name("output/prediction").outputs[0]#从节点"output/prediction" 加载张量 y_prediction。注意.outputs[0]是指
#节点的第一个输出,即y_prediction
pred = sess.run(y_prediction,feed_dict={x_input:X})#进行预测
print "prediction is ",pred


输出结果为:prediction is [1]

3 不足之处

保存模型时没有设置checkpoints,应该是训练一定步数保存一次,也没有实现自动加载最新的模型,接下来应当实现。
内容来自用户分享和网络整理,不保证内容的准确性,如有侵权内容,可联系管理员处理 点击这里给我发消息
标签: