TensorFlow in Practice (4): Training and Prediction with TFRecord Data and the tf.contrib.slim API

2018-08-30 21:55

Using MNIST as an example, this post shows how to train a model and run prediction with TFRecord data and the tf.contrib.slim API.

References:

1. https://github.com/tensorflow/tensorflow/tree/master/tensorflow/contrib/slim

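Contents

1. Data Input

2. Model Definition

3. Model Training

4. Model Validation

1. Data Input

Data input is the same as in the previous post:

TensorFlow in Practice (3): Training and Prediction with TFRecord Data and the tf.estimator API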

2. Model Definition

[code]import tensorflow as tf

slim = tf.contrib.slim


def model_slim(images, labels, is_training):
    # Two conv/pool stages followed by two fully connected layers.
    net = slim.conv2d(images, 32, [5, 5], scope='conv1')
    net = slim.max_pool2d(net, [2, 2], stride=2, scope='pool1')
    net = slim.conv2d(net, 64, [5, 5], scope='conv2')
    net = slim.max_pool2d(net, [2, 2], stride=2, scope='pool2')
    net = slim.flatten(net, scope='flatten')
    net = slim.fully_connected(net, 1024, scope='fully_connected1')
    # Dropout is only active when is_training is True.
    net = slim.dropout(net, keep_prob=0.6, is_training=is_training)
    logits = slim.fully_connected(net, 10, activation_fn=None, scope='fully_connected2')

    prob = slim.softmax(logits)
    loss = slim.losses.sparse_softmax_cross_entropy(logits, labels)

    global_step = tf.train.get_or_create_global_step()
    # TRAIN_EXAMPLES_NUM and FLAGS are defined elsewhere in this module.
    num_batches_per_epoch = TRAIN_EXAMPLES_NUM / FLAGS.batch_size
    decay_steps = int(num_batches_per_epoch * 10)

    # Decay the learning rate exponentially based on the number of steps.
    lr = tf.train.exponential_decay(learning_rate=0.001,
                                    global_step=global_step,
                                    decay_steps=decay_steps,
                                    decay_rate=0.1,
                                    staircase=True)

    opt = tf.train.AdamOptimizer(learning_rate=lr)

    return opt, loss, prob
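
To see what the exponential_decay settings above amount to, the staircase schedule can be reproduced in plain Python. This is a minimal sketch of the documented formula, lr * decay_rate ^ floor(step / decay_steps), not TensorFlow's own code. The dataset size 55000 and batch size 128 are assumptions chosen for illustration, since the post does not state TRAIN_EXAMPLES_NUM or FLAGS.batch_size.

```python
import math


def exponential_decay(base_lr, global_step, decay_steps, decay_rate, staircase=True):
    """Return the decayed learning rate for a given global step."""
    exponent = global_step / decay_steps
    if staircase:
        # Staircase mode keeps the rate constant within each decay interval.
        exponent = math.floor(exponent)
    return base_lr * decay_rate ** exponent


# Assumed values (not given in the post), used only for illustration.
TRAIN_EXAMPLES_NUM = 55000
BATCH_SIZE = 128

num_batches_per_epoch = TRAIN_EXAMPLES_NUM / BATCH_SIZE
decay_steps = int(num_batches_per_epoch * 10)  # decay once every 10 epochs

for step in (0, decay_steps - 1, decay_steps, 2 * decay_steps, 9999):
    lr = exponential_decay(0.001, step, decay_steps, 0.1)
    print('step {:5d}: lr = {:.0e}'.format(step, lr))
```

Under these assumed values the rate drops by a factor of 10 every 4296 steps, so a 10000-step run would see two drops.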

slim keeps model definitions concise. This model is too simple to show the benefit, but with deeper networks slim's repeat and arg_scope become very useful.

3. Model Training

[code]def train():
    train_images, train_labels = mnist.input_fn(['./train_img.tfrecords'], True)

    # model_slim returns the optimizer (not a train op) as its first value.
    optimizer, loss, pred = mnist.model_slim(train_images, train_labels, is_training=True)
    train_tensor = slim.learning.create_train_op(loss, optimizer)
    result = slim.learning.train(train_tensor, FLAGS.train_dir, number_of_steps=FLAGS.max_step, log_every_n_steps=100)
    print('final step loss: {}'.format(result))

Training itself is simple. With tf.logging.set_verbosity(tf.logging.INFO) enabled, a log line with the current loss is printed every log_every_n_steps steps. The output looks like this:

[code]...
INFO:tensorflow:global step 9699: loss = 0.0084 (0.004 sec/step)
INFO:tensorflow:global step 9799: loss = 0.0000 (0.004 sec/step)
INFO:tensorflow:global step 9899: loss = 0.0001 (0.007 sec/step)
INFO:tensorflow:global step 9999: loss = 0.0002 (0.004 sec/step)
INFO:tensorflow:Stopping Training.
INFO:tensorflow:Finished training! Saving model to disk.
final step loss: 0.00017899183148983866

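When training with slim, it is inconvenient to run evaluation on the validation set during training. This is discussed in a TensorFlow issue, where someone proposes a workaround that may be worth trying.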

4. Model Validation

[code]def validation():
    validation_images, validation_labels = mnist.input_fn(['./validation_img.tfrecords'], False)
    _, loss, pred = mnist.model_slim(validation_images, validation_labels, is_training=False)
    prediction = tf.argmax(pred, axis=1)

    # Choose the metrics to compute:
    value_op, update_op = tf.metrics.accuracy(validation_labels, prediction)
    # Number of batches needed to cover the validation set once (requires `import math`).
    num_batches = math.ceil(mnist.VALIDATION_EXAMPLES_NUM / FLAGS.batch_size)

    print('Running evaluation...')
    # Only load latest checkpoint
    checkpoint_path = tf.train.latest_checkpoint(FLAGS.train_dir)

    metric_values = slim.evaluation.evaluate_once(
        num_evals=num_batches,
        master='',
        checkpoint_path=checkpoint_path,
        logdir=FLAGS.train_dir,
        eval_op=update_op,
        final_op=value_op)
    print(metric_values)

Validation result:

[code]INFO:tensorflow:Restoring parameters from ./train\model.ckpt-10000
INFO:tensorflow:Running local_init_op.
INFO:tensorflow:Done running local_init_op.
INFO:tensorflow:Evaluation [4/40]
INFO:tensorflow:Evaluation [8/40]
INFO:tensorflow:Evaluation [12/40]
INFO:tensorflow:Evaluation [16/40]
INFO:tensorflow:Evaluation [20/40]
INFO:tensorflow:Evaluation [24/40]
INFO:tensorflow:Evaluation [28/40]
INFO:tensorflow:Evaluation [32/40]
INFO:tensorflow:Evaluation [36/40]
INFO:tensorflow:Evaluation [40/40]
INFO:tensorflow:Finished evaluation at 2018-08-30-13:52:28
0.992

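slim offers two ways to run evaluation: evaluate_once and evaluation_loop. evaluate_once runs a single evaluation and exits, while evaluation_loop keeps running and re-evaluates at a fixed time interval.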

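One thing here differs from the official documentation. The documentation defines the metric with slim.metrics.accuracy(predictions, labels), but reading the source I found that this function returns only a single accuracy `Tensor`, so I used tf.metrics.accuracy instead, which returns a (value, update op) pair. I am not sure whether this is a version difference.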
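The difference between a single accuracy tensor and the (value, update op) pair can be illustrated without TensorFlow. The sketch below is a plain-Python stand-in, not the library's implementation: a streaming metric accumulates counts across batches, which is why evaluate_once runs the update op once per batch and reads the value only at the end. The class name, the helper function, and the sample labels are made up for illustration.

```python
class StreamingAccuracy:
    """Accumulates correct/total counts across batches."""

    def __init__(self):
        self.correct = 0
        self.total = 0

    def update(self, labels, predictions):
        # Plays the role of the update op: fold one batch into the counts.
        self.correct += sum(int(l == p) for l, p in zip(labels, predictions))
        self.total += len(labels)
        return self.value()

    def value(self):
        # Plays the role of the value op: accuracy over everything seen so far.
        return self.correct / self.total if self.total else 0.0


def batch_accuracy(labels, predictions):
    """Accuracy of one batch only, with no state kept between calls."""
    return sum(int(l == p) for l, p in zip(labels, predictions)) / len(labels)


# Hypothetical (labels, predictions) batches; the last batch is smaller.
batches = [
    ([7, 2, 1, 0], [7, 2, 1, 0]),
    ([4, 1, 4, 9], [4, 1, 4, 4]),
    ([5, 9], [6, 9]),
]

metric = StreamingAccuracy()
for labels, predictions in batches:
    metric.update(labels, predictions)

per_batch = [batch_accuracy(l, p) for l, p in batches]
print('streaming accuracy:', metric.value())
print('mean of per-batch accuracies:', sum(per_batch) / len(per_batch))
```

Here the streaming result is 8 correct out of 10, while averaging the three per-batch accuracies would give 0.75 because the last batch is smaller.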

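I find estimator more convenient than slim's API for training and validation. One option is to define the model with slim and then train and validate with estimator; I may try that when I get the chance.

Finally, the complete code: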

https://github.com/buptlj/learn_tf