tensorflow:fully_connected_feed.py代码详细中文注释
2017-05-14 15:39
393 查看
"""Trains and Evaluates the MNIST network using a feed dictionary.""" from __future__ import absolute_import from __future__ import division from __future__ import print_function # pylint: disable=missing-docstring import argparse import os.path import sys import time from six.moves import xrange # pylint: disable=redefined-builtin import tensorflow as tf #导入tensorflow模块下的input_data.py文件以及mnist.py文件 from tensorflow.examples.tutorials.mnist import input_data from tensorflow.examples.tutorials.mnist import mnist #FLAGS用于存储模型的基本参数,比如训练数据存放的文件夹的位置等 FLAGS = None def placeholder_inputs(batch_size): """Generate placeholder variables to represent the input tensors. These placeholders are used as inputs by the rest of the model building code and will be fed from the downloaded data in the .run() loop, below. Args: batch_size: The batch size will be baked into both placeholders. Returns: images_placeholder: Images placeholder. labels_placeholder: Labels placeholder. """ # Note that the shapes of the placeholders match the shapes of the full # image and label tensors, except the first dimension is now batch_size # rather than the full size of the train or test data sets. images_placeholder = tf.placeholder(tf.float32, shape=(batch_size, mnist.IMAGE_PIXELS)) labels_placeholder = tf.placeholder(tf.int32, shape=(batch_size)) return images_placeholder, labels_placeholder def fill_feed_dict(data_set, images_pl, labels_pl): """Fills the feed_dict for training the given step. A feed_dict takes the form of: feed_dict = { <placeholder>: <tensor of values to be passed for placeholder>, .... } Args: data_set: The set of images and labels, from input_data.read_data_sets() images_pl: The images placeholder, from placeholder_inputs(). labels_pl: The labels placeholder, from placeholder_inputs(). Returns: feed_dict: The feed dictionary mapping from placeholders to values. 
""" #为占位符创建一个feed_dict,里面的内容是数据集中的下一个batch大小的数据 images_feed, labels_feed = data_set.next_batch(FLAGS.batch_size, FLAGS.fake_data) feed_dict = { images_pl: images_feed, labels_pl: labels_feed, } return feed_dict def do_eval(sess, eval_correct, images_placeholder, labels_placeholder, data_set): """Runs one evaluation against the full epoch of data. Args: sess: The session in which the model has been trained. eval_correct: The Tensor that returns the number of correct predictions. images_placeholder: The images placeholder. labels_placeholder: The labels placeholder. data_set: The set of images and labels to evaluate, from input_data.read_data_sets(). """ true_count = 0 # 正确预测结果的数量 steps_per_epoch = data_set.num_examples // FLAGS.batch_size#//为除法后结果四舍五入 num_examples = steps_per_epoch * FLAGS.batch_size #对整个输入的数据集进行一次评价 for step in xrange(steps_per_epoch): feed_dict = fill_feed_dict(data_set, images_placeholder, labels_placeholder) true_count += sess.run(eval_correct, feed_dict=feed_dict) precision = float(true_count) / num_examples#用预测正确的数量除以全部的数据量即为准确率 print(' Num examples: %d Num correct: %d Precision @ 1: %0.04f' % (num_examples, true_count, precision)) def run_training(): """Train MNIST for a number of steps.""" #获取数据集,包括了训练集、验证集以及测试集 data_sets = input_data.read_data_sets(FLAGS.input_data_dir, FLAGS.fake_data) #告诉tensorflow模型将会被构建入默认的图中,因为第一步是构建图表 with tf.Graph().as_default(): #为图片和标签创建占位符 images_placeholder, labels_placeholder = placeholder_inputs( FLAGS.batch_size) #创建一个从推理模型 logits = mnist.inference(images_placeholder, FLAGS.hidden1, FLAGS.hidden2) #在图表中加入计算损失函数的op操作 loss = mnist.loss(logits, labels_placeholder) #在图表中加入使用梯度的op操作 train_op = mnist.training(loss, FLAGS.learning_rate) #在图表中加入比较logits预测以及label的op操作,在调用do_eval函数中会用到 eval_correct = mnist.evaluation(logits, labels_placeholder) # Build the summary Tensor based on the TF collection of Summaries. 
summary = tf.summary.merge_all() # 加入变量初始化的op init = tf.global_variables_initializer() #创建一个存储器来写入训练时候的检查点 saver = tf.train.Saver() # 为了运行图中的op创建一个会话 sess = tf.Session() # 实例化一个总结写入器 summary_writer = tf.summary.FileWriter(FLAGS.log_dir, sess.graph) #运行初始化所有变量的op--initial sess.run(init) #开始循环训练 for step in xrange(FLAGS.max_steps): start_time = time.time() #使用fill_feed_dict函数获取图片和标签的字典,字典形式: #feed_dict = { #images_pl: images_feed, #labels_pl: labels_feed, #} #该字典作为后面用来替代图片和标签占位符 feed_dict = fill_feed_dict(data_sets.train, images_placeholder, labels_placeholder) #因为run里面有由两个op组成的列表,因此返回会是两个值,因为train_op没有返回值,所以我们只用到了loss_value即当前损失函数的值 _, loss_value = sess.run([train_op, loss], feed_dict=feed_dict) duration = time.time() - start_time #每100次训练就输出得到的损失函数的数据 if step % 100 == 0: print('Step %d: loss = %.2f (%.3f sec)' % (step, loss_value, duration)) summary_str = sess.run(summary, feed_dict=feed_dict) summary_writer.add_summary(summary_str, step) summary_writer.flush() #保存检查点并且一定周期的评估训练得到的模型在训练集、验证集以及测试集上的性能 #每1000次训练就在整个数据集上进行一次模型的评估 if (step + 1) % 1000 == 0 or (step + 1) == FLAGS.max_steps: checkpoint_file = os.path.join(FLAGS.log_dir, 'model.ckpt') saver.save(sess, checkpoint_file, global_step=step) #开始在训练集上评估模型 print('Training Data Eval:') do_eval(sess, eval_correct, images_placeholder, labels_placeholder, data_sets.train) #开始在验证集上评估模型 print('Validation Data Eval:') do_eval(sess, eval_correct, images_placeholder, labels_placeholder, data_sets.validation) #开始在测试集上评估模型 print('Test Data Eval:') do_eval(sess, eval_correct, images_placeholder, labels_placeholder, data_sets.test) def main(_): if tf.gfile.Exists(FLAGS.log_dir): tf.gfile.DeleteRecursively(FLAGS.log_dir) tf.gfile.MakeDirs(FLAGS.log_dir) run_training() #如果是直接使用python fully_connected_feed.py的指令则从此处开始运行程序 if __name__ == '__main__': #存储训练时候的一些参数 parser = argparse.ArgumentParser() #此处是学习速率,如果将其变小,会使得每一个batch的训练loss改变很小,但是却很准确,读者可以试着将其改为0.005,其他不动,会发现准确率会下降,但如果将下一个参数max_step变大,会发现准确率会比原始的91的准确率高一些 
parser.add_argument( '--learning_rate', type=float, default=0.01, help='Initial learning rate.' ) #此处是表示要迭代训练多少次,即要用多少个batch来进行训练,一般这个参数越大会使得最后的准确率越高 parser.add_argument( '--max_steps', type=int, default=2000, help='Number of steps to run trainer.' ) #此处是在hidden1层中的单元数量 parser.add_argument( '--hidden1', type=int, default=128, help='Number of units in hidden layer 1.' ) #此处是在hidden2层中的单元数量 parser.add_argument( '--hidden2', type=int, default=32, help='Number of units in hidden layer 2.' ) #此处是在批梯度下降法的训练中每一批的样本的数量,梯度下降法是每一次更新权值的时候使用了全部的训练集合,但这在数据量巨大的时候是低效率的,因此采用批梯度下降法,每一次用batch size个样本参与训练 parser.add_argument( '--batch_size', type=int, default=100, help='Batch size. Must divide evenly into the dataset sizes.' ) #此处是数据存放的目录 parser.add_argument( '--input_data_dir', type=str, default='/tmp/tensorflow/mnist/input_data', help='Directory to put the input data.' ) #此处是日志文件存放的目录 parser.add_argument( '--log_dir', type=str, default='/tmp/tensorflow/mnist/logs/fully_connected_feed', help='Directory to put the log data.' ) parser.add_argument( '--fake_data', default=False, help='If true, uses fake data for unit testing.', action='store_true' ) FLAGS, unparsed = parser.parse_known_args() tf.app.run(main=main, argv=[sys.argv[0]] + unparsed)
相关文章推荐
- [07]tensorflow源码例子mnist源码——fully_connected_feed.py
- tensorflow源码例子mnist源码——fully_connected_feed.py
- Tensorflow:fully_connected_feed.py运行报错
- tensorflow学习fully_connected_feed.py
- TensorFlow入门 fully_connected_feed.py
- mnist的Tensorflow官方模板(fully_connected_feed.py文件中参数解析问题)
- fully_connected_feed代码说明
- TensorFlow实现用于图像分类的卷积神经网络(代码详细注释)
- Tensorflow: fully_connected_feed.py运行报错
- Tensorboard打开方式学习笔记(基于mnist的fully_connected_feed.py模块)
- tensorflow 最小二乘拟合详细代码注释
- SqlHelper详细中文注释
- ASP.Net 2.0 窗体身份验证机制-转+自己代码注释示例与更详细的说明
- SqlHelper(带详细中文注释)
- 分页机制代码详细注释
- ICTCLAS 中科院分词系统 代码 注释 中文分词 词性标注
- linux 0.11源代码完全中文注释
- PHP 图片上传实现代码 带详细注释
- SqlHelper(带详细中文注释)
- SqlHelper(带详细中文注释)