您的位置:首页 > 编程语言

tensorflow编程基础

2018-09-04 20:21 323 查看

模型构建中的几个概念。张量:数据,即某一类型的多维数组。变量:常用于定义模型中的参数,是通过不断训练得到的值。占位符:输入变量的载体,也可以理解为定义函数时的参数。图中的节点操作(OP):即一个OP获得0个或者多个tensor,执行计算输出额外的0个或多个tensor。在python中,返回的tensor是numpy.ndarray对象。

在具体的项目中,会有三种应用场景,分别是训练场景、测试场景和使用场景。训练场景是实现模型从无到有的过程,通过对样本的学习训练,调整学习参数,形成最终的模型。测试场景和使用场景:测试场景是利用图的正向运算得到的结果与真实值进行比较的差别;使用场景也是利用图的正向运算得到结果,并直接使用。二者的运算过程是一样的。这个过程特别像普通编程中使用函数的过程:实参就是输入的样本,形参就是占位符,运算过程就相当于函数体,得到的结果相当于返回值。

session与图的交互过程中还定义了以下两种数据的流向机制。注入机制(feed):通过占位符向模式中传入数据;取回机制(fetch):从模式中得到结果。

演示注入机制:需要注意的是,feed只在调用它的方法内有效,方法结束后feed就会消失。代码如下所示:

[code]import tensorflow as tf
a=tf.placeholder(tf.int16)
b=tf.placeholder(tf.int16)
add=tf.add(a,b)
mul=tf.multiply(a,b)
with tf.Session() as sess:
print("相加:%i" % sess.run(add,feed_dict={a:3,b:4}))
print('相乘:%i' % sess.run(mul,feed_dict={a:3,b:5}))

保存和载入模型

[code]saver=tf.train.Saver()
with tf.Session() as sess:
sess.run(tf.global_variables_initializer())
....
saver.save(sess,'save_patch+filename')

需要注意的是,save_patch的路径要在session的创建之前,模型保存在代码的同级目录下。而载入模型则通过在session中调用saver的restore()函数,从指定的路径找到对应名称的模型。除了在训练结束以后,在训练中也可以保存模型。这样当训练模型出现中断时,可以得到保存到的中间参数,习惯称之为保存检查点。

[code]import tensorflow as tf
import numpy as np
import os
os.environ["TF_CPP_MIN_LOG_LEVEL"]='2'
train_X=np.linspace(-1,1,100)
train_Y=2*train_X+np.random.randn(100)*0.3
X=tf.placeholder('float')
Y=tf.placeholder('float')
W=tf.Variable(tf.random_normal([1]),name='weight')
b=tf.Variable(tf.zeros([1]),name='bias')
z=tf.multiply(X,W)+b
cost=tf.reduce_mean(tf.square(z-Y))
learning_rate=0.01
optimizer=tf.train.GradientDescentOptimizer(learning_rate).minimize(cost)
initial=tf.global_variables_initializer()
training_epoch=20
display_step=2
saver=tf.train.Saver(max_to_keep=1)
savedir='log/'
with tf.Session() as sess:
sess.run(initial)
for epoch in range(training_epoch):
for (x,y) in zip(train_X,train_Y):
sess.run(optimizer,feed_dict={X:x,Y:y})
if epoch % display_step ==0:
loss=sess.run(cost,feed_dict={X:x,Y:y})
print('epoch:',epoch+1,'loss:',loss,'W:',sess.run(W),'b:',sess.run(b))
saver.save(sess,savedir+'linermodule.cpkt',global_step=epoch)
print('finished')
print('loss:',sess.run(cost,feed_dict={X:x,Y:y}),'w:',sess.run(W),'b:',sess.run(b))
load_epoch=18
with tf.Session() as sess2:
sess2.run(tf.global_variables_initializer())
saver.restore(sess2,savedir+'linermodule.cpkt-'+str(load_epoch))
print('x=0.2,z=',sess2.run(z,feed_dict={X:0.2}))

保存的检查点文件如下所示:

因为设置max_to_keep=1,所以在迭代的过程中只保存一个文件。在训练的过程中,新生成的模型会覆盖以前的模型。

运行结果如下图所示:

 

阅读更多
内容来自用户分享和网络整理,不保证内容的准确性,如有侵权内容,可联系管理员处理 点击这里给我发消息
标签: