您的位置:首页 > 理论基础 > 计算机网络

TensorFlow深度学习进阶教程:TensorFlow实现CIFAR-10数据集测试的卷积神经网络

2017-10-11 19:50 951 查看
TensorFlow深度学习进阶教程:TensorFlow实现CIFAR-10数据集测试的卷积神经网络

     

     TensorFlow是一个非常强大的用来做大规模数值计算的库。其所擅长的任务之一就是实现以及训练深度神经网络。本教程使用的数据集是CIFAR-10,这是一个非常经典的数据集,包含60000张32×32的彩色图像,其中训练集50000张,测试集10000张。对CIFAR-10 数据集的分类是机器学习中一个公开的基准测试问题,其任务是对一组32x32RGB的图像进行分类,这些图像涵盖了10个类别:
飞机,
汽车, 鸟, 猫, 鹿, 狗, 青蛙, 马, 船以及卡车。




       想了解更多信息请参考CIFAR-10 page,以及Alex Krizhevsky写的技术报告。本教程适用于对Tensorflow有丰富经验的用户,并假定用户有机器学习相关领域的专业知识和经验。
      在本教程中,我们将学到构建一个TensorFlow CNN模型的基本步骤,并将通过这些步骤为CIFAR-10T构建一个深度卷积神经网络。这个教程假设你已经熟悉神经网络和CIFAR-10数据集。

       本教程的目标是建立一个用于识别图像的相对较小的卷积神经网络,在这一过程中,本教程会:
      (1)着重于建立一个规范的网络组织结构,训练并进行评估
      (2)为建立更大规模更加复杂的模型提供一个范例
      选择CIFAR-10是因为它的复杂程度足以用来检验TensorFlow中的大部分功能,并可将其扩展为更大的模型。与此同时由于模型较小所以训练速度很快,比较适合用来测试新的想法,检验新的技术。
        在这个卷积神经网络中,使用来一些新的技巧:
       (1)对weights进行来L2的正则化;

       (2)对图片进行翻转、随机剪切等数据增强,制造来更多的样本;

       (3)在每个卷积-最大池化层后面使用了LRN层,增强来模型的泛化能力。

        在开始编程前,我们需要在github上下载TensorFlow Models库,以便使用其中提供的CIFAR-10数据的类。

        本教程是先将CIFAR-10数据集下载下来,放到指定路径,然后直接在程序中加载,当然读者也可以参照程序中的另一种方法,直接从网上加载,本人倾向于事先下载好。

        在准备工作就绪后,就可以构建CIFAR-10测试集的卷积神经网络。以下代是本人根据自己的理解和参考编写而成,并加有注释,如有错误请指正。

# -*- coding: utf-8 -*-
import os
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'

# 首先确保在github上已下载TensorFlow Models库
# 并将models/tutorials/image/cifar10文件中的.py文件添加到本项目中
# 载入常用库、读取CIFAR-10数据类
import cifar10, cifar10_input
import tensorflow as tf
import numpy as np
import time

# 定义batch_size、训练轮数max_steps,以及下载CIFAR-10数据的默认路径
max_steps = 3000
batch_size = 128
data_dir = '/home/guoguo16/DL_Projects/CIFAR10-CNN/cifar-10-batches-bin'

# 定义初始化weight函数
def variable_with_weight_loss(shape, stddev, w1):
var = tf.Variable(tf.truncated_normal(shape, stddev=stddev))
if w1 is not None:
weight_loss = tf.multiply(tf.nn.l2_loss(var), w1, name='weight_loss')
tf.add_to_collection('losses', weight_loss)
return var

# 使用cifar10类下载数据集,并解压、展开到其默认位置(若事先下载好数据集,此行代码无用)
#cifar10.maybe_download_and_extract()

# 产生训练需要的数据,包括特征及其对应的label
images_train, labels_train = cifar10_input.distorted_inputs(data_dir=data_dir,
batch_size=batch_size)

# 使用cifar10_input.inputs函数生成测试数据
images_test, labels_test = cifar10_input.inputs(eval_data=True,
data_dir=data_dir,
batch_size=batch_size)

# 创建输入数据的placeholder,包括特征和label
image_holder = tf.placeholder(tf.float32, [batch_size, 24, 24, 3])
label_holder = tf.placeholder(tf.int32, [batch_size])

# 创建conv1
weight1 = variable_with_weight_loss(shape=[5, 5, 3, 64], stddev=5e-2, w1=0.0)
kernel1 = tf.nn.conv2d(image_holder, weight1, [1, 1, 1, 1], padding='SAME')
bias1 = tf.Variable(tf.constant(0.0, shape=[64]))
conv1 = tf.nn.relu(tf.nn.bias_add(kernel1, bias1))
pool1 = tf.nn.max_pool(conv1, ksize=[1, 3, 3, 1], strides=[1, 2, 2, 1], padding='SAME')
norm1 = tf.nn.lrn(pool1, 4, bias=1.0, alpha=0.001/9.0, beta=0.75)

# 创建conv2
weight2 = variable_with_weight_loss(shape=[5, 5, 64, 64], stddev=5e-2, w1=0.0)
kernel2 = tf.nn.conv2d(norm1, weight2, [1, 1, 1, 1], padding='SAME')
bias2 = tf.Variable(tf.constant(0.1, shape=[64]))
conv2 = tf.nn.relu(tf.nn.bias_add(kernel2, bias2))
norm2 = tf.nn.lrn(conv2, 4, bias=1.0, alpha=0.001/9.0, beta=0.75)
pool2 = tf.nn.max_pool(norm2, ksize=[1, 3, 3, 1], strides=[1, 2, 2, 1], padding='SAME')

# 创建fc1
reshape = tf.reshape(pool2, [batch_size, -1])
dim = reshape.get_shape()[1].value
weight3 = variable_with_weight_loss(shape=[dim, 384], stddev=0.04, w1=0.004)
bias3 = tf.Variable(tf.constant(0.1, shape=[384]))
local3 = tf.nn.relu(tf.matmul(reshape, weight3) + bias3)

# 创建fc2
weight4 = variable_with_weight_loss(shape=[384, 192], stddev=0.04, w1=0.004)
bias4 = tf.Variable(tf.constant(0.1, shape=[192]))
local4 = tf.nn.relu(tf.matmul(local3, weight4) + bias4)

# 创建logits
weight5 = variable_with_weight_loss(shape=[192, 10], stddev=1/192, w1=0.0)
bias5 = tf.Variable(tf.constant(0.0, shape=[10]))
logits = tf.add(tf.matmul(local4, weight5), bias5)
# ................................................................
# 以上为整个网络的inference部分
# .................................................................

# 计算CNN的loss
def loss(logits, labels):
labels = tf.cast(labels, tf.int64)
cross_entropy = tf.nn.sparse_softmax_cross_entropy_with_logits(
logits=logits, labels=labels, name='cross_entropy_per_example')
cross_entropy_mean = tf.reduce_mean(cross_entropy, name='cross_entropy')
tf.add_to_collection('losses', cross_entropy_mean)
return tf.add_n(tf.get_collection('losses'), name='total_loss')

# 将logits节点和label_holder传入loss函数获得最终loss
loss = loss(logits, label_holder)

# 优化器选择Adam Optimizer,学习率设为1e-3
train_op = tf.train.AdamOptimizer(1e-3).minimize(loss)

# 用tf.nn.in_top_k函数求输出结果中top k的准确率
top_k_op = tf.nn.in_top_k(logits, label_holder, 1)

# 创建默认session,初始化全部模型参数
sess = tf.InteractiveSession()
tf.global_variables_initializer().run()

# 启动图片数据增强线程队列
tf.train.start_queue_runners()

# 正式开始训练
for step in range(max_steps):
start_time = time.time()
image_batch, label_batch = sess.run([images_train, labels_train])
_, loss_value = sess.run([train_op, loss],
feed_dict={image_holder: image_batch,
label_holder: label_batch})
duration = time.time() - start_time

if step % 10 == 0:
examples_per_sec = batch_size / duration
sec_per_batch = float(duration)

format_str = ('step %d, loss = %.2f (%.1f examples/sec; %.3f sec/batch)')
print(format_str % (step, loss_value, examples_per_sec, sec_per_batch))

# 评测模型在测试集上的准确率
num_examples = 10000
import math
num_iter = int(math.ceil(num_examples / batch_size))
true_count = 0
total_sample_count = num_iter * batch_size
step = 0
while step < num_iter:
image_batch, label_batch = sess.run([images_test, labels_test])
predictions = sess.run([top_k_op], feed_dict={image_holder: image_batch,
label_holder: label_batch})
true_count += np.sum(predictions)
step += 1

# 将准确率的评测结果打印出来
precision = true_count / total_sample_count
print('precision @ 1 = %.3f' % precision)
         最终,采用程序中给出的参数,在CIFAR-10数据集上,通过一个短时间小迭代次数的训练,可以达到71%左右的准确率。持续增加max_steps,可以期望准确率逐渐增加。希望本教程给大家开了个头,使得在Tensorflow上可以为视觉相关工作建立更大型的Cnns模型。

       在后续工作中,我将继续为大家展现TensorFlow带来的无尽乐趣,我将和大家一起探讨深度学习的奥秘。当然,如果你感兴趣,我的Weibo将与你一起分享最前沿的人工智能、机器学习、深度学习与计算机视觉方面的技术。
内容来自用户分享和网络整理,不保证内容的准确性,如有侵权内容,可联系管理员处理 点击这里给我发消息
相关文章推荐