您的位置:首页 > 其它

TensorFlow实现Softmax Regressin识别手写数字

2019-05-12 14:25 423 查看
版权声明:本文为博主原创文章,未经博主允许不得转载。如需转载请与博主联系,并标注来源。 https://blog.csdn.net/weixin_41606064/article/details/90060698

在深度学习领域的应用中,我们不得不提一个很主流的框架—TensorFlow,它是一个基于数据流编程的符号数学系统,目前,已经被广泛各类机器学习的算法的编程实现中,其前身是谷歌的神经网络算法库DistBelief。关于TensorFlow的更多知识介绍可以点击这里
下面我们开始正式介绍如何使用TensorFlow实现Softmax Regressin识别手写数字这个很基础的小案例。(注:本文主体内容为博主在学习《TensorFlow实战》后所总结凝练的内容,本文的案例实现是在Ubuntu系统中安装TensorFlow后,并在pycharm的IDE环境中实现的)

1. 数据集的获取

  • 本次案例的实现采用MINST(Mixed National Institute of Standards and Technology database)数据集,它是一个非常简单且实用性很高的一个机器视觉数据集,里面包含了许多手写的数字的图片,而且这些手写的数字图片只包含灰度值信息,刚好复合本次案例的数据集的要求。
  • 在安装TensorFlow后,可以在pycharm中新建一个python工程命名为HandwrittenDigtalRecongition,在该工程目录下,新建一个从c1.py文件作为源代码文件。下面开始编写代码实现导入MNIST数据集:
# 下载数据集
from tensorflow.examples.tutorials.mnist import input_data
mnist = input_data.read_data_sets("MNIST_data/", one_hot=True)  # 使用one-hot编码
print(mnist.train.images.shape, mnist.train.labels.shape)
print(mnist.test.images.shape, mnist.test.labels.shape)
print(mnist.validation.images.shape, mnist.validation.labels.shape)

执行代码后的控制台输出如下:

/usr/bin/python3.6 /home/chenchi/PycharmProjects/HandwrittenDigtalRecongition/c1.py
Extracting MNIST_data/train-images-idx3-ubyte.gz
WARNING:tensorflow:From /home/chenchi/PycharmProjects/HandwrittenDigtalRecongition/c1.py:6: read_data_sets (from tensorflow.contrib.learn.python.learn.datasets.mnist) is deprecated and will be removed in a future version.
Instructions for updating:
Please use alternatives such as official/mnist/dataset.py from tensorflow/models.
WARNING:tensorflow:From /usr/local/lib/python3.6/dist-packages/tensorflow/contrib/learn/python/learn/datasets/mnist.py:260: maybe_download (from tensorflow.contrib.learn.python.learn.datasets.base) is deprecated and will be removed in a future version.
Instructions for updating:
Please write your own downloading logic.
WARNING:tensorflow:From /usr/local/lib/python3.6/dist-packages/tensorflow/contrib/learn/python/learn/datasets/mnist.py:262: extract_images (from tensorflow.contrib.learn.python.learn.datasets.mnist) is deprecated and will be removed in a future version.
Instructions for updating:
Please use tf.data to implement this functionality.
WARNING:tensorflow:From /usr/local/lib/python3.6/dist-packages/tensorflow/contrib/learn/python/learn/datasets/mnist.py:267: extract_labels (from tensorflow.contrib.learn.python.learn.datasets.mnist) is deprecated and will be removed in a future version.
Extracting MNIST_data/train-labels-idx1-ubyte.gz
Instructions for updating:
Please use tf.data to implement this functionality.
WARNING:tensorflow:From /usr/local/lib/python3.6/dist-packages/tensorflow/contrib/learn/python/learn/datasets/mnist.py:110: dense_to_one_hot (from tensorflow.contrib.learn.python.learn.datasets.mnist) is deprecated and will be removed in a future version.
Instructions for updating:
Please use tf.one_hot on tensors.
WARNING:tensorflow:From /usr/local/lib/python3.6/dist-packages/tensorflow/contrib/learn/python/learn/datasets/mnist.py:290: DataSet.__init__ (from tensorflow.contrib.learn.python.learn.datasets.mnist) is deprecated and will be removed in a future version.
Instructions for updating:
Extracting MNIST_data/t10k-images-idx3-ubyte.gz
Extracting MNIST_data/t10k-labels-idx1-ubyte.gz
Please use alternatives such as official/mnist/dataset.py from tensorflow/models.
(55000, 784) (55000, 10)
(10000, 784) (10000, 10)
(5000, 784) (5000, 10)

可以看到MINIST数据集已经成功下载及解压,同时我们还可以看到训练集有55000个样本,测试集有10000个样本,验证集有5000个样本。其中,每一个样本都有它所对应的标注信息。补充一句784=2828,这个是因为我们数据集的图像是28像素28像素大小的灰度图片。

2.one-hot编码

在机器学习中,我们常常会遇到需要分类的样本,这些样本并不是连续,而是离散的,无序的。通常我们需要将其进行数字特征化处理,便可以采取one-hot的编码方式来处理这样的样本数据集。one-hot编码,译为独热码,简单来说就是有多少个状态就有多少比特,而且只有一个比特为1,其他全为0的一种码制 。如在通讯网络协议中,使用8位或者16位状态的独热码,且系统占用其中一个状态码,余下的可以供用户使用。
再回到我们本次采用的样本数据集中,我们所有识别的为0-9个数字,10个种类,数据的label是一个10维的向量,只有一个位是1,其他均为0。比如数字2对应的label为【0,0,1,0,0,0,0,0,0】,数字7对应的label为【0,0,0,0,0,0,0,1,0,0】,数字n就代表对应的位置的值为1。

3.Softmax Regression算法实现

本案例采用Softmax Regression算法训练手写识别的分类模型,当我们处理多分类的任务时,便可以采取此类模型,softmax逻辑回归模型是logistic回归模型在多分类问题上的推广,在多分类问题中,类标签y可以取两个以上的值。了解更多可以点击这里

  • 定义算法公式,也就是神经网络forward时的计算
  • 定义loss,选择优化器,并指定优化器loss
  • 迭代地对数据进行训练
  • 在测试集上对准确率进行评测
    此部分内容的详细解释可以参考《TensorFlow实战 黄文坚 唐源 著》

4.完整代码实现

# 下载本案例所需的数据集
import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data

mnist = input_data.read_data_sets("MNIST_data/", one_hot=True)  # 使用one-hot编码
print(mnist.train.images.shape, mnist.train.labels.shape) # 输出数据集相关信息,包括样本集图片的数量,图像的像素等
print(mnist.test.images.shape, mnist.test.labels.shape)
print(mnist.validation.images.shape, mnist.validation.labels.shape)

sess = tf.InteractiveSession()
# 第一步,定义算法公式
x = tf.placeholder(tf.float32, [None, 784])  # 构建占位符,None表示样本的数量可以是任意的
W = tf.Variable(tf.zeros([784, 10]))  # 构建一个变量,代表权重矩阵,初始化为0
b = tf.Variable(tf.zeros([10]))  # 构建一个变量,代表偏置,初始化为0
y = tf.nn.softmax(tf.matmul(x, W) + b)  # 构建了一个softmax的模型:y = softmax(Wx + b),y指样本标签的预测值

# 第二步,定义损失函数,选定优化器,并指定优化器优化损失函数
y_ = tf.placeholder(tf.float32, [None, 10])
# 交叉熵损失函数
cross_entropy = -tf.reduce_sum(y_ * tf.log(y))
# 使用梯度下降法最小化cross_entropy损失函数
train_step = tf.train.GradientDescentOptimizer(0.01).minimize(cross_entropy)

# 第三步,迭代地对数据进行训练
tf.global_variables_initializer().run()
for i in range(1000):  # 迭代次数1000
batch_xs, batch_ys = mnist.train.next_batch(100)  # 使用minibatch,一个batch大小为100
train_step.run({x: batch_xs, y_: batch_ys})

# 第四步,在测试集或验证集上对准确率进行评测
correct_prediction = tf.equal(tf.argmax(y, 1), tf.argmax(y_, 1))  # tf.argmax()返回的是某一维度上其数据最大所在的索引值,在这里即代表预测值和真值
accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))  # 用平均值来统计测试准确率
print(accuracy.eval({x: mnist.test.images, y_: mnist.test.labels}))  # 打印测试信息

运行结果:

2019-05-10 12:46:34.939199: I tensorflow/core/platform/cpu_feature_guard.cc:141] Your CPU supports instructions that this TensorFlow binary was not compiled to use: FMA
2019-05-10 12:46:37.266697: I tensorflow/core/platform/profile_utils/cpu_utils.cc:94] CPU Frequency: 2100000000 Hz
2019-05-10 12:46:37.333611: I tensorflow/compiler/xla/service/service.cc:150] XLA service 0x1f54120 executing computations on platform Host. Devices:
2019-05-10 12:46:37.333688: I tensorflow/compiler/xla/service/service.cc:158]   StreamExecutor device (0): <undefined>, <undefined>
WARNING:tensorflow:From /usr/local/lib/python3.6/dist-packages/tensorflow/python/framework/op_def_library.py:263: colocate_with (from tensorflow.python.framework.ops) is deprecated and will be removed in a future version.
Instructions for updating:
Colocations handled automatically by placer.
2019-05-10 12:47:11.866633: W tensorflow/core/framework/allocator.cc:124] Allocation of 31360000 exceeds 10% of system memory.
0.9187
Process finished with exit code 0

可以看到预测精度为0.9187

内容来自用户分享和网络整理,不保证内容的准确性,如有侵权内容,可联系管理员处理 点击这里给我发消息
标签: 
相关文章推荐