tensorflow编程基础
模型构建中的几个概念。张量:数据,即某一类型的多维数组。变量:常用于定义模型中的参数,是通过不断训练得到的值。占位符:输入变量的载体,也可以理解为定义函数时的参数。图中的节点操作(OP):即一个OP获得0个或者多个tensor,执行计算输出额外的0个或多个tensor。在python中,返回的tensor是numpy.ndarray对象。
在具体的项目中,会有三种应用场景,分别是训练场景、测试场景和使用场景。训练场景是实现模型从无到有的过程,通过对样本的学习训练,调整学习参数,形成最终的模型。测试场景和使用场景:测试场景是利用图的正向运算得到的结果与真实值进行比较的差别;使用场景也是利用图的正向运算得到结果,并直接使用。二者的运算过程是一样的。这个过程特别像普通编程中使用函数的过程:实参就是输入的样本,形参就是占位符,运算过程就相当于函数体,得到的结果相当于返回值。
session与图的交互过程中还定义了以下两种数据的流向机制。注入机制(feed):通过占位符向模式中传入数据;取回机制(fetch):从模式中得到结果。
演示注入机制:需要注意的是,feed只在调用它的方法内有效,方法结束后feed就会消失。代码如下所示:
[code]import tensorflow as tf a=tf.placeholder(tf.int16) b=tf.placeholder(tf.int16) add=tf.add(a,b) mul=tf.multiply(a,b) with tf.Session() as sess: print("相加:%i" % sess.run(add,feed_dict={a:3,b:4})) print('相乘:%i' % sess.run(mul,feed_dict={a:3,b:5}))
保存和载入模型
[code]saver=tf.train.Saver() with tf.Session() as sess: sess.run(tf.global_variables_initializer()) .... saver.save(sess,'save_patch+filename')
需要注意的是,save_patch的路径要在session的创建之前,模型保存在代码的同级目录下。而载入模型则通过在session中调用saver的restore()函数,从指定的路径找到对应名称的模型。除了在训练结束以后,在训练中也可以保存模型。这样当训练模型出现中断时,可以得到保存到的中间参数,习惯称之为保存检查点。
[code]import tensorflow as tf import numpy as np import os os.environ["TF_CPP_MIN_LOG_LEVEL"]='2' train_X=np.linspace(-1,1,100) train_Y=2*train_X+np.random.randn(100)*0.3 X=tf.placeholder('float') Y=tf.placeholder('float') W=tf.Variable(tf.random_normal([1]),name='weight') b=tf.Variable(tf.zeros([1]),name='bias') z=tf.multiply(X,W)+b cost=tf.reduce_mean(tf.square(z-Y)) learning_rate=0.01 optimizer=tf.train.GradientDescentOptimizer(learning_rate).minimize(cost) initial=tf.global_variables_initializer() training_epoch=20 display_step=2 saver=tf.train.Saver(max_to_keep=1) savedir='log/' with tf.Session() as sess: sess.run(initial) for epoch in range(training_epoch): for (x,y) in zip(train_X,train_Y): sess.run(optimizer,feed_dict={X:x,Y:y}) if epoch % display_step ==0: loss=sess.run(cost,feed_dict={X:x,Y:y}) print('epoch:',epoch+1,'loss:',loss,'W:',sess.run(W),'b:',sess.run(b)) saver.save(sess,savedir+'linermodule.cpkt',global_step=epoch) print('finished') print('loss:',sess.run(cost,feed_dict={X:x,Y:y}),'w:',sess.run(W),'b:',sess.run(b)) load_epoch=18 with tf.Session() as sess2: sess2.run(tf.global_variables_initializer()) saver.restore(sess2,savedir+'linermodule.cpkt-'+str(load_epoch)) print('x=0.2,z=',sess2.run(z,feed_dict={X:0.2}))
保存的检查点文件如下所示:
因为设置max_to_keep=1,所以在迭代的过程中只保存一个文件。在训练的过程中,新生成的模型会覆盖以前的模型。
运行结果如下图所示:
阅读更多
- 第一阶段-入门详细图文讲解tensorflow1.4 -(三)TensorFlow 编程基础知识
- DeepLearning | Tensorflow编程基础:Session、Constant、Variable、Tensor、Placeholder、OP
- 黑马程序员 自学06C#编程基础之循环(此文无for循环)
- Java Socket网络编程基础
- bash编程-Shell基础
- 嵌入式软件开发培训笔记——c编程基础
- C# Socket编程基础入门
- FreeCodeCamp日志-基础算法编程完成
- 计算机科学和编程导论-week1-编程基础
- Tensorflow 基础
- X Window编程基础 2
- 跟着姜少学Java基础编程之三:变量
- TensorFlow 编程概念
- 多线程编程基础(线程创建)
- C语言基础编程之进制转化
- 【Matlab】之 编程基础(一)
- Linux应用编程基础--(2)文件IO
- C#多线程编程实战(一):线程基础
- Shell脚本编程基础 四 更多的结构化命令
- Java网络编程基础(一)