您的位置:首页 > Web前端

tensorflow:fully_connected_feed.py代码详细中文注释

2017-05-14 15:39 393 查看
"""Trains and Evaluates the MNIST network using a feed dictionary."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

# pylint: disable=missing-docstring
import argparse
import os.path
import sys
import time

from six.moves import xrange  # pylint: disable=redefined-builtin
import tensorflow as tf

#导入tensorflow模块下的input_data.py文件以及mnist.py文件
from tensorflow.examples.tutorials.mnist import input_data
from tensorflow.examples.tutorials.mnist import mnist

#FLAGS用于存储模型的基本参数,比如训练数据存放的文件夹的位置等
FLAGS = None

def placeholder_inputs(batch_size):
"""Generate placeholder variables to represent the input tensors.
These placeholders are used as inputs by the rest of the model building
code and will be fed from the downloaded data in the .run() loop, below.
Args:
batch_size: The batch size will be baked into both placeholders.
Returns:
images_placeholder: Images placeholder.
labels_placeholder: Labels placeholder.
"""
# Note that the shapes of the placeholders match the shapes of the full
# image and label tensors, except the first dimension is now batch_size
# rather than the full size of the train or test data sets.
images_placeholder = tf.placeholder(tf.float32, shape=(batch_size,
mnist.IMAGE_PIXELS))
labels_placeholder = tf.placeholder(tf.int32, shape=(batch_size))
return images_placeholder, labels_placeholder

def fill_feed_dict(data_set, images_pl, labels_pl):
"""Fills the feed_dict for training the given step.
A feed_dict takes the form of:
feed_dict = {
<placeholder>: <tensor of values to be passed for placeholder>,
....
}
Args:
data_set: The set of images and labels, from input_data.read_data_sets()
images_pl: The images placeholder, from placeholder_inputs().
labels_pl: The labels placeholder, from placeholder_inputs().
Returns:
feed_dict: The feed dictionary mapping from placeholders to values.
"""
#为占位符创建一个feed_dict,里面的内容是数据集中的下一个batch大小的数据
images_feed, labels_feed = data_set.next_batch(FLAGS.batch_size,
FLAGS.fake_data)
feed_dict = {
images_pl: images_feed,
labels_pl: labels_feed,
}
return feed_dict

def do_eval(sess,
eval_correct,
images_placeholder,
labels_placeholder,
data_set):
"""Runs one evaluation against the full epoch of data.
Args:
sess: The session in which the model has been trained.
eval_correct: The Tensor that returns the number of correct predictions.
images_placeholder: The images placeholder.
labels_placeholder: The labels placeholder.
data_set: The set of images and labels to evaluate, from
input_data.read_data_sets().
"""

true_count = 0  # 正确预测结果的数量
steps_per_epoch = data_set.num_examples // FLAGS.batch_size#//为除法后结果四舍五入
num_examples = steps_per_epoch * FLAGS.batch_size
#对整个输入的数据集进行一次评价
for step in xrange(steps_per_epoch):
feed_dict = fill_feed_dict(data_set,
images_placeholder,
labels_placeholder)
true_count += sess.run(eval_correct, feed_dict=feed_dict)
precision = float(true_count) / num_examples#用预测正确的数量除以全部的数据量即为准确率
print('  Num examples: %d  Num correct: %d  Precision @ 1: %0.04f' %
(num_examples, true_count, precision))

def run_training():
"""Train MNIST for a number of steps."""
#获取数据集,包括了训练集、验证集以及测试集
data_sets = input_data.read_data_sets(FLAGS.input_data_dir, FLAGS.fake_data)

#告诉tensorflow模型将会被构建入默认的图中,因为第一步是构建图表
with tf.Graph().as_default():
#为图片和标签创建占位符
images_placeholder, labels_placeholder = placeholder_inputs(
FLAGS.batch_size)
#创建一个从推理模型
logits = mnist.inference(images_placeholder,
FLAGS.hidden1,
FLAGS.hidden2)

#在图表中加入计算损失函数的op操作
loss = mnist.loss(logits, labels_placeholder)

#在图表中加入使用梯度的op操作
train_op = mnist.training(loss, FLAGS.learning_rate)

#在图表中加入比较logits预测以及label的op操作,在调用do_eval函数中会用到
eval_correct = mnist.evaluation(logits, labels_placeholder)

# Build the summary Tensor based on the TF collection of Summaries.
summary = tf.summary.merge_all()

# 加入变量初始化的op
init = tf.global_variables_initializer()

#创建一个存储器来写入训练时候的检查点
saver = tf.train.Saver()

# 为了运行图中的op创建一个会话
sess = tf.Session()

# 实例化一个总结写入器
summary_writer = tf.summary.FileWriter(FLAGS.log_dir, sess.graph)

#运行初始化所有变量的op--initial
sess.run(init)

#开始循环训练
for step in xrange(FLAGS.max_steps):
start_time = time.time()

#使用fill_feed_dict函数获取图片和标签的字典,字典形式:
#feed_dict = {
#images_pl: images_feed,
#labels_pl: labels_feed,
#}
#该字典作为后面用来替代图片和标签占位符

feed_dict = fill_feed_dict(data_sets.train,
images_placeholder,
labels_placeholder)

#因为run里面有由两个op组成的列表,因此返回会是两个值,因为train_op没有返回值,所以我们只用到了loss_value即当前损失函数的值
_, loss_value = sess.run([train_op, loss],
feed_dict=feed_dict)

duration = time.time() - start_time

#每100次训练就输出得到的损失函数的数据
if step % 100 == 0:
print('Step %d: loss = %.2f (%.3f sec)' % (step, loss_value, duration))

summary_str = sess.run(summary, feed_dict=feed_dict)
summary_writer.add_summary(summary_str, step)
summary_writer.flush()

#保存检查点并且一定周期的评估训练得到的模型在训练集、验证集以及测试集上的性能
#每1000次训练就在整个数据集上进行一次模型的评估
if (step + 1) % 1000 == 0 or (step + 1) == FLAGS.max_steps:
checkpoint_file = os.path.join(FLAGS.log_dir, 'model.ckpt')
saver.save(sess, checkpoint_file, global_step=step)
#开始在训练集上评估模型
print('Training Data Eval:')
do_eval(sess,
eval_correct,
images_placeholder,
labels_placeholder,
data_sets.train)
#开始在验证集上评估模型
print('Validation Data Eval:')
do_eval(sess,
eval_correct,
images_placeholder,
labels_placeholder,
data_sets.validation)
#开始在测试集上评估模型
print('Test Data Eval:')
do_eval(sess,
eval_correct,
images_placeholder,
labels_placeholder,
data_sets.test)

def main(_):
if tf.gfile.Exists(FLAGS.log_dir):
tf.gfile.DeleteRecursively(FLAGS.log_dir)
tf.gfile.MakeDirs(FLAGS.log_dir)
run_training()

#如果是直接使用python fully_connected_feed.py的指令则从此处开始运行程序
if __name__ == '__main__':
#存储训练时候的一些参数
parser = argparse.ArgumentParser()
#此处是学习速率,如果将其变小,会使得每一个batch的训练loss改变很小,但是却很准确,读者可以试着将其改为0.005,其他不动,会发现准确率会下降,但如果将下一个参数max_step变大,会发现准确率会比原始的91的准确率高一些
parser.add_argument(
'--learning_rate',
type=float,
default=0.01,
help='Initial learning rate.'
)
#此处是表示要迭代训练多少次,即要用多少个batch来进行训练,一般这个参数越大会使得最后的准确率越高
parser.add_argument(
'--max_steps',
type=int,
default=2000,
help='Number of steps to run trainer.'
)
#此处是在hidden1层中的单元数量
parser.add_argument(
'--hidden1',
type=int,
default=128,
help='Number of units in hidden layer 1.'
)
#此处是在hidden2层中的单元数量
parser.add_argument(
'--hidden2',
type=int,
default=32,
help='Number of units in hidden layer 2.'
)
#此处是在批梯度下降法的训练中每一批的样本的数量,梯度下降法是每一次更新权值的时候使用了全部的训练集合,但这在数据量巨大的时候是低效率的,因此采用批梯度下降法,每一次用batch size个样本参与训练
parser.add_argument(
'--batch_size',
type=int,
default=100,
help='Batch size.  Must divide evenly into the dataset sizes.'
)
#此处是数据存放的目录
parser.add_argument(
'--input_data_dir',
type=str,
default='/tmp/tensorflow/mnist/input_data',
help='Directory to put the input data.'
)
#此处是日志文件存放的目录
parser.add_argument(
'--log_dir',
type=str,
default='/tmp/tensorflow/mnist/logs/fully_connected_feed',
help='Directory to put the log data.'
)
parser.add_argument(
'--fake_data',
default=False,
help='If true, uses fake data for unit testing.',
action='store_true'
)

FLAGS, unparsed = parser.parse_known_args()
tf.app.run(main=main, argv=[sys.argv[0]] + unparsed)
内容来自用户分享和网络整理,不保证内容的准确性,如有侵权内容,可联系管理员处理 点击这里给我发消息
标签: