Tensorflow实战Google-第五章mnist数字识别
2017-10-03 23:59
369 查看
import tensorflow as tf from tensorflow.examples.tutorials.mnist import input_data import mnist_inference import os # mnist_train.py # 主训练类,用于训练逻辑 # 分片大小、基础学习率、衰减率、正则化率、训练次数、滑动平均衰减率 BATCH_SIZE = 100 LEARNING_RATE_BASE = 0.8 LEARNING_RATE_DECAY = 0.99 REGULARIZATION_RATE = 0.0001 TRAINING_STEPS = 30000 MOVING_AVERAGE_DECAY = 0.99 # 模型名和模型保存路径 MODEL_SAVE_PATH="MNIST_model/" MODEL_NAME="mnist_model" def train(mnist): # 定义样本输入输出数据 x = tf.placeholder(tf.float32, [None, mnist_inference.INPUT_NODE], name='x-input') y_ = tf.placeholder(tf.float32, [None, mnist_inference.OUTPUT_NODE], name='y-input') # 正则化,用于防止过拟合 # 当模型过于复杂之后,就会出现模型太过于匹配训练样本,而不能很好地适应测试样本, # 称这种现象为过拟合现象; # 正则化的思想是在损失函数中加入刻画模型复杂度的指标,这里成为正则化函数 # 这样整个损失函数就是J(wi) + r*R(wi);这里的wi是指所有的权重w, # 显然当w过于复杂时,R(wi)就会越大,制约损失函数的值,反过来限制R(wi)的值 # 参考链接:http://blog.csdn.net/jinping_shi/article/details/52433975 # REGULARIZATION_RATE是正则化权重 regularizer = tf.contrib.layers.l2_regularizer(REGULARIZATION_RATE) # 通过输入数据和正则化函数向前传播 y = mnist_inference.inference(x, regularizer) # 滑动平均模型:我的理解,为了防止权重等变量可能出现大的突升或者突降,我们使用了一个"缓兵之计", # 即,使得变量变化不要太大,这样模型将更加稳定健壮。 # variable2'=shadow_variable2=decay×shadow_variable1+(1−decay)×variable2 # 这里的shadow_variable1是变量variable的初始值,表示为v1;公式中的variable2为variable改变后得值,为v2 # decay为衰减率,variable2'-variable2为采用滑动后的-未采用的; # 显然,|variable2'-variable1|<|variable2-variable1| # 另外为了使模型在训练前期更新能够尽可能快,我们有队decay进行了函数处理;使得变化更大; # 比如变量X从0-10改变,decay=0.9; # 只单独采用滑动平均之后的值是0-1;在采用decay函数处理后,就是0-9,显然后面的变化更快一些 # 想法:其实真这样,感觉滑动平均就有点多此一举了 # decay=min{decay,1+num_updates / 10+num_updates} # 这里的num_updates参数就是下面的global_step,一般初始化为0,不可更改 global_step = tf.Variable(0, trainable=False) variable_averages = tf.train.ExponentialMovingAverage(MOVING_AVERAGE_DECAY, global_step) # 将定义的滑动平均variable_averages应用到所有的参数中,variables_averages_op即为参数的更新动作; # 即每执行一次variables_averages_op,就会更新一次全部的参数 variables_averages_op = variable_averages.apply(tf.trainable_variables()) # 损失函数:损失函数有很多,这里使用交叉熵作为损失函数,H(p,q)=−∑xp(x)log(q(x)) # 通俗的解释是交叉熵表示的是p和q之间的相似的,其中q是输出值所对应的概率,p是该样本的正确输出 # 另外,q的曲解需要使用到softmax函数,这个函数主要使用在多分类中,目的是将分类结果转化为所出现的概率q # 所有的概率和为1 # 函数argmax表示按行/列 取y_中的最大值;0-1表示列-行,问这里为什么还按行?不是只有一行吗? cross_entropy = tf.nn.sparse_softmax_cross_entropy_with_logits(logits=y, labels=tf.argmax(y_, 1)) # 求均值 cross_entropy_mean = tf.reduce_mean(cross_entropy) # 猜测:这里的tf.get_collection('losses'),可能是累计所有的正则化后的参数加到损失函数中 loss = cross_entropy_mean + tf.add_n(tf.get_collection('losses')) # 定义学习率 # staircase:True表示整个样本训练一次才更新学习率;False表示,每训练每一步都更新学习率(一个batch_size) learning_rate = tf.train.exponential_decay( LEARNING_RATE_BASE, # 基础学习率,之后的在这个基础上递减 global_step, # 当前迭代轮数 mnist.train.num_examples / BATCH_SIZE, # 训练完所有的样本需要的迭代次数 LEARNING_RATE_DECAY, # 学习率衰减率 staircase=True) # 以learning_rate学习率梯度下降计算损失值loss,global_step为L2正则化参数 train_step = tf.train.GradientDescentOptimizer(learning_rate).minimize(loss, global_step=global_step) # control_dependenciesshi实现两个过程处理:反向传播更新参数train_step、滑动平均更新参数variables_averages_op # with指如果执行成功,就执行内部逻辑 with tf.control_dependencies([train_step, variables_averages_op]): train_op = tf.no_op(name='train') # 保存训练好的模型;即持久化模型 # 初始化持久化类 saver = tf.train.Saver() # 以下训练说明参考第三单元说明 with tf.Session() as sess: tf.global_variables_initializer().run() for i in range(TRAINING_STEPS): xs, ys = mnist.train.next_batch(BATCH_SIZE) _, loss_value, step = sess.run([train_op, loss, global_step], feed_dict={x: xs, y_: ys}) if i % 1000 == 0: print("After %d training step(s), loss on training batch is %g." % (step, loss_value)) saver.save(sess, os.path.join(MODEL_SAVE_PATH, MODEL_NAME), global_step=global_step) def main(argv=None): mnist = input_data.read_data_sets("/MNIST_data", one_hot=True) train(mnist) if __name__ == '__main__': tf.app.run()
import tensorflow as tf # mnist_inference.py # 主网络构建,用于构建网络结构 INPUT_NODE = 784 OUTPUT_NODE = 10 LAYER1_NODE = 500 def get_weight_variable(shape, regularizer): weights = tf.get_variable("weights", shape, initializer=tf.truncated_normal_initializer(stddev=0.1)) if regularizer != None: tf.add_to_collection('losses', regularizer(weights)) return weights def inference(input_tensor, regularizer): with tf.variable_scope('layer1'): weights = get_weight_variable([INPUT_NODE, LAYER1_NODE], regularizer) biases = tf.get_variable("biases", [LAYER1_NODE], initializer=tf.constant_initializer(0.0)) layer1 = tf.nn.relu(tf.matmul(input_tensor, weights) + biases) with tf.variable_scope('layer2'): weights = get_weight_variable([LAYER1_NODE, OUTPUT_NODE], regularizer) biases = tf.get_variable("biases", [OUTPUT_NODE], initializer=tf.constant_initializer(0.0)) layer2 = tf.matmul(layer1, weights) + biases return layer2
相关文章推荐
- (Tensorflow之八)MNIST数字识别源码--实战Google深度学习框架5.2小节
- tensorflow实战之四:MNIST手写数字识别的优化3-过拟合
- TensorFlow实战——MNIST数字识别问题
- TensorFlow实战-mnist手写数字识别(卷积神经网络)
- tensorflow实战之二:MNIST手写数字识别的优化1-代价函数优化
- Tensorflow项目实战一:MNIST手写数字识别
- 30分钟手把手带你入门TensorFlow——Mnist手写数字识别实战教程
- TensorFlow实战——CNN(LeNet5)——MNIST数字识别
- TensorFlow实战5:利用卷积神经网络对图像分类(初阶:MNIST手写数字)代码实现
- Tensorflow-MNIST数字识别练习代码
- TensorFlow - 手写数字识别 (MNIST), 多类分类 (multiclass classification) 问题
- tensorflow——用RNN实现MNIST手写数字识别
- 用MXnet入门实战深度学习之一:安装GPU版mxnet并跑一个MNIST手写数字识别
- tensorflow 入门小例子(mnist手写数字识别)
- TensorFlow 深度学习框架(6)-- mnist 数字识别及不同模型效果比较
- TensorFlow 深度学习框架(6)-- mnist 数字识别及不同模型效果比较
- tensorflow实战(一)TensorFlow实现 softmax Regression 识别手写数字
- TensorFlow笔记之一:MNIST手写数字识别
- Tensorflow解决MNIST手写体数字识别
- 【TensorFlow-windows】(四) CNN(卷积神经网络)进行手写数字识别(mnist)