TensorFlow实现Softmax Regressin识别手写数字
在深度学习领域的应用中,我们不得不提一个很主流的框架—TensorFlow,它是一个基于数据流编程的符号数学系统,目前,已经被广泛各类机器学习的算法的编程实现中,其前身是谷歌的神经网络算法库DistBelief。关于TensorFlow的更多知识介绍可以点击这里
下面我们开始正式介绍如何使用TensorFlow实现Softmax Regressin识别手写数字这个很基础的小案例。(注:本文主体内容为博主在学习《TensorFlow实战》后所总结凝练的内容,本文的案例实现是在Ubuntu系统中安装TensorFlow后,并在pycharm的IDE环境中实现的)
1. 数据集的获取
- 本次案例的实现采用MINST(Mixed National Institute of Standards and Technology database)数据集,它是一个非常简单且实用性很高的一个机器视觉数据集,里面包含了许多手写的数字的图片,而且这些手写的数字图片只包含灰度值信息,刚好复合本次案例的数据集的要求。
- 在安装TensorFlow后,可以在pycharm中新建一个python工程命名为HandwrittenDigtalRecongition,在该工程目录下,新建一个从c1.py文件作为源代码文件。下面开始编写代码实现导入MNIST数据集:
# 下载数据集 from tensorflow.examples.tutorials.mnist import input_data mnist = input_data.read_data_sets("MNIST_data/", one_hot=True) # 使用one-hot编码 print(mnist.train.images.shape, mnist.train.labels.shape) print(mnist.test.images.shape, mnist.test.labels.shape) print(mnist.validation.images.shape, mnist.validation.labels.shape)
执行代码后的控制台输出如下:
/usr/bin/python3.6 /home/chenchi/PycharmProjects/HandwrittenDigtalRecongition/c1.py Extracting MNIST_data/train-images-idx3-ubyte.gz WARNING:tensorflow:From /home/chenchi/PycharmProjects/HandwrittenDigtalRecongition/c1.py:6: read_data_sets (from tensorflow.contrib.learn.python.learn.datasets.mnist) is deprecated and will be removed in a future version. Instructions for updating: Please use alternatives such as official/mnist/dataset.py from tensorflow/models. WARNING:tensorflow:From /usr/local/lib/python3.6/dist-packages/tensorflow/contrib/learn/python/learn/datasets/mnist.py:260: maybe_download (from tensorflow.contrib.learn.python.learn.datasets.base) is deprecated and will be removed in a future version. Instructions for updating: Please write your own downloading logic. WARNING:tensorflow:From /usr/local/lib/python3.6/dist-packages/tensorflow/contrib/learn/python/learn/datasets/mnist.py:262: extract_images (from tensorflow.contrib.learn.python.learn.datasets.mnist) is deprecated and will be removed in a future version. Instructions for updating: Please use tf.data to implement this functionality. WARNING:tensorflow:From /usr/local/lib/python3.6/dist-packages/tensorflow/contrib/learn/python/learn/datasets/mnist.py:267: extract_labels (from tensorflow.contrib.learn.python.learn.datasets.mnist) is deprecated and will be removed in a future version. Extracting MNIST_data/train-labels-idx1-ubyte.gz Instructions for updating: Please use tf.data to implement this functionality. WARNING:tensorflow:From /usr/local/lib/python3.6/dist-packages/tensorflow/contrib/learn/python/learn/datasets/mnist.py:110: dense_to_one_hot (from tensorflow.contrib.learn.python.learn.datasets.mnist) is deprecated and will be removed in a future version. Instructions for updating: Please use tf.one_hot on tensors. WARNING:tensorflow:From /usr/local/lib/python3.6/dist-packages/tensorflow/contrib/learn/python/learn/datasets/mnist.py:290: DataSet.__init__ (from tensorflow.contrib.learn.python.learn.datasets.mnist) is deprecated and will be removed in a future version. Instructions for updating: Extracting MNIST_data/t10k-images-idx3-ubyte.gz Extracting MNIST_data/t10k-labels-idx1-ubyte.gz Please use alternatives such as official/mnist/dataset.py from tensorflow/models. (55000, 784) (55000, 10) (10000, 784) (10000, 10) (5000, 784) (5000, 10)
可以看到MINIST数据集已经成功下载及解压,同时我们还可以看到训练集有55000个样本,测试集有10000个样本,验证集有5000个样本。其中,每一个样本都有它所对应的标注信息。补充一句784=2828,这个是因为我们数据集的图像是28像素28像素大小的灰度图片。
2.one-hot编码
在机器学习中,我们常常会遇到需要分类的样本,这些样本并不是连续,而是离散的,无序的。通常我们需要将其进行数字特征化处理,便可以采取one-hot的编码方式来处理这样的样本数据集。one-hot编码,译为独热码,简单来说就是有多少个状态就有多少比特,而且只有一个比特为1,其他全为0的一种码制 。如在通讯网络协议中,使用8位或者16位状态的独热码,且系统占用其中一个状态码,余下的可以供用户使用。
再回到我们本次采用的样本数据集中,我们所有识别的为0-9个数字,10个种类,数据的label是一个10维的向量,只有一个位是1,其他均为0。比如数字2对应的label为【0,0,1,0,0,0,0,0,0】,数字7对应的label为【0,0,0,0,0,0,0,1,0,0】,数字n就代表对应的位置的值为1。
3.Softmax Regression算法实现
本案例采用Softmax Regression算法训练手写识别的分类模型,当我们处理多分类的任务时,便可以采取此类模型,softmax逻辑回归模型是logistic回归模型在多分类问题上的推广,在多分类问题中,类标签y可以取两个以上的值。了解更多可以点击这里
- 定义算法公式,也就是神经网络forward时的计算
- 定义loss,选择优化器,并指定优化器loss
- 迭代地对数据进行训练
- 在测试集上对准确率进行评测
此部分内容的详细解释可以参考《TensorFlow实战 黄文坚 唐源 著》
4.完整代码实现
# 下载本案例所需的数据集 import tensorflow as tf from tensorflow.examples.tutorials.mnist import input_data mnist = input_data.read_data_sets("MNIST_data/", one_hot=True) # 使用one-hot编码 print(mnist.train.images.shape, mnist.train.labels.shape) # 输出数据集相关信息,包括样本集图片的数量,图像的像素等 print(mnist.test.images.shape, mnist.test.labels.shape) print(mnist.validation.images.shape, mnist.validation.labels.shape) sess = tf.InteractiveSession() # 第一步,定义算法公式 x = tf.placeholder(tf.float32, [None, 784]) # 构建占位符,None表示样本的数量可以是任意的 W = tf.Variable(tf.zeros([784, 10])) # 构建一个变量,代表权重矩阵,初始化为0 b = tf.Variable(tf.zeros([10])) # 构建一个变量,代表偏置,初始化为0 y = tf.nn.softmax(tf.matmul(x, W) + b) # 构建了一个softmax的模型:y = softmax(Wx + b),y指样本标签的预测值 # 第二步,定义损失函数,选定优化器,并指定优化器优化损失函数 y_ = tf.placeholder(tf.float32, [None, 10]) # 交叉熵损失函数 cross_entropy = -tf.reduce_sum(y_ * tf.log(y)) # 使用梯度下降法最小化cross_entropy损失函数 train_step = tf.train.GradientDescentOptimizer(0.01).minimize(cross_entropy) # 第三步,迭代地对数据进行训练 tf.global_variables_initializer().run() for i in range(1000): # 迭代次数1000 batch_xs, batch_ys = mnist.train.next_batch(100) # 使用minibatch,一个batch大小为100 train_step.run({x: batch_xs, y_: batch_ys}) # 第四步,在测试集或验证集上对准确率进行评测 correct_prediction = tf.equal(tf.argmax(y, 1), tf.argmax(y_, 1)) # tf.argmax()返回的是某一维度上其数据最大所在的索引值,在这里即代表预测值和真值 accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32)) # 用平均值来统计测试准确率 print(accuracy.eval({x: mnist.test.images, y_: mnist.test.labels})) # 打印测试信息
运行结果:
2019-05-10 12:46:34.939199: I tensorflow/core/platform/cpu_feature_guard.cc:141] Your CPU supports instructions that this TensorFlow binary was not compiled to use: FMA 2019-05-10 12:46:37.266697: I tensorflow/core/platform/profile_utils/cpu_utils.cc:94] CPU Frequency: 2100000000 Hz 2019-05-10 12:46:37.333611: I tensorflow/compiler/xla/service/service.cc:150] XLA service 0x1f54120 executing computations on platform Host. Devices: 2019-05-10 12:46:37.333688: I tensorflow/compiler/xla/service/service.cc:158] StreamExecutor device (0): <undefined>, <undefined> WARNING:tensorflow:From /usr/local/lib/python3.6/dist-packages/tensorflow/python/framework/op_def_library.py:263: colocate_with (from tensorflow.python.framework.ops) is deprecated and will be removed in a future version. Instructions for updating: Colocations handled automatically by placer. 2019-05-10 12:47:11.866633: W tensorflow/core/framework/allocator.cc:124] Allocation of 31360000 exceeds 10% of system memory. 0.9187 Process finished with exit code 0
可以看到预测精度为0.9187
- TensorFlow实现Softmax Regression识别手写数字
- TensorFlow实战 3.2实现softmax regression识别手写数字
- Tensorflow实现Softmax Regression识别手写数字
- TensorFlow 实现 Softmax Regression 识别手写数字
- TensorFlow实现Softmax Regression 识别手写数字(3.2节)
- Tensorflow实现Softmax Regressoin手写识别数字---TensorFlow实战3.2节
- Tensorflow实战学习(二十四)【实现Softmax Regression(回归)识别手写数字】
- TensorFlow(二)实现Softmax Regression 识别手写数字
- TensorFlow实现Softmax Regression识别手写数字中"TimeoutError: [WinError 10060] 由于连接方在一段时间后没有正确答复或连接的主机没有反应,连接尝试失败”问题
- Tensorflow实现MNIST手写数字识别(Softmax Regression)
- TensorFlow实现Softmax Regression识别手写数字
- tensorflow实现softmax regression识别手写数字
- Python实战 | TensorFlow之softmax的实现——手写数字识别
- 学习笔记TF024:TensorFlow实现Softmax Regression(回归)识别手写数字
- TensorFlow 实现Softmax Regression识别手写数字
- tensorflow实战(一)TensorFlow实现 softmax Regression 识别手写数字
- tensorflow 第一个程序MNIST手写数字识别(Softmax Regression实现)
- 【TensorFlow-windows】(一)实现Softmax Regression进行手写数字识别(mnist)
- 用TensorFlow的Softmax Regression进行手写数字识别
- Tensorflow的Helloword:使用简单Softmax Regression模型来识别Mnist手写数字