tensorflow训练mnist数据集-识别手写数字
1、加载mnist数据集
TensorFlow提供了mnist加载的封装,在python中直接运行以下代码即可完成mnist数据加载。
运行成功会出现如下现实:
下载好的mnist数据集在C:\Users\%你的主机名%文件夹下:
然后我们来查看一下这个数据集的情况,print(mnisr.trian.images.shape,mnist.train.labels.shape),该语句输出结果是:(55000,784) (55000,10)表示了训练集images数据一共有55000个样本,每一个样本长度是784,标签一共是10个数字。print(mnist.test.images.shape,minst.test.labels.shape),该语句输出的结果是(10000,784) (10000,10)表示测试集images数据一共10000个样本,每个样本长度是784,标签一共10个数字;print(mnist.validation.images.shape,mnist.validation.labels.shape),该语句输出的结果是(5000,784) (5000,10)表示image验证集有5000个样本,每个样本长度是784,标签一共10个数字。三种类型数据集的作用:在训练集上训练模型,利用验证集检验模型效果并决定何时完成训练,最后在测试机评价模型的效果。
>>> import tensorflow as tf
#载入TensorFlow库
>>> sess = tf.InteractiveSession()
#创建一个新的InteractiveSession,使用这个命令会将session注册为默认的session,之后的运算也默认泡在这个session里
#面,不同的session之间的数据和运算是相互的里的。
>>> x = tf.placeholder(tf.float32,[None,784])
#创建一个Placeholder,这个是输入数据的地方;Placeholder的第一个参数是数据类型,第二个参数【none,784】代表
#Tensor的shape,也是数据的尺寸,这里none代表不限条数的输入,784表示每条输入时一个784 维的向量。
>>> w = tf.Variable(tf.zeros([784,10]))>>> b = tf.Variable(tf.zeros([10]))
#给Softmax Regression模型中的weights和biases创建Variable对象(用来存储模型参数)我们把weights和biases全部初始化
#为0
>>> y = tf.nn.softmax(tf.matmul(x,w) + b)
#实现Softmax Regression模型算法
>>> y_ = tf.placeholder(tf.float32, [None,10])#定义一个Placeholder,输入真是的label,用来计算损失函数
>>> cross_entropy = tf.reduce_mean(-tf.reduce_sum(y_ * tf.log(y),reduction_indices=[1]))
#定义了一个损失函数cross-entropy
>>> train_step = tf.train.GradientDescentOptimizer(0.5).minimize(cross_entropy)
#这一步是训练,设置学习速率是0.5
>>> tf.global_variables_initializer().run()
#下一步使用TensorFlow的全局参数初始化tf.global_variables_initializer,然后调用run()方法。
>>> for i in range(1000):... batch_xs,batch_ys = mnist.train.next_batch(100)
... train_step.run({x:batch_xs, y_:batch_ys})
...
#迭代地进行训练操作train_step。
>>> correct_prediction = tf.equal(tf.argmax(y,1),tf.argmax(y_,1))
#完成训练后,对模型的准确率进行验证。
>>> accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))
#统计全部预测的accuracy,这里需要先用tf.cast将之前输出结果bool转换为float32,再求平均值
>>> print(accuracy.eval({x:mnist.test.images,y_:mnist.test.labels}))0.9192
#最后输出结果,准确率大概在91%~92%之间。
>>>阅读更多
- 使用tensorflow利用神经网络分类识别MNIST手写数字数据集,转自随心1993
- 利用tensorflow一步一步实现基于MNIST 数据集进行手写数字识别的神经网络,逻辑回归
- 深度学习与TensorFlow实战(六)全连接网络基础—MNIST数据集输出手写数字识别准确率
- 将tensorflow训练好的模型移植到Android (MNIST手写数字识别)
- 训练Tensorflow识别手写数字 mnist
- tensorflow 学习专栏(五):在mnist数据集上使用tensorflow实现临近算法(Nearest-Neighbor)进行手写数字识别
- TensorFlow用MNIST训练的模型来识别手写数字
- caffe示例实现之4在MNIST手写数字数据集上训练与测试LeNet
- Keras_深度学习_MNIST数据集手写数字识别之各种调参
- tensorflow 第一个程序MNIST手写数字识别(Softmax Regression实现)
- tensorflow进行MNIST手写数字识别-LSTM
- [置顶] java实现基于Mnist数据集的手写数字识别
- tensorflow入门实践例子—MNIST手写数字识别
- Tensorflow的Helloword:使用简单Softmax Regression模型来识别Mnist手写数字
- 介绍保存与读取Keras模型的方法,并对MNIST数据集的训练模型尝试进行手写识别
- Tensorflow-mnist 手写数字识别
- TensorFlow在MNIST中的应用 识别手写数字(OpenCV+TensorFlow+CNN)
- TensorFlow学习笔记(3)--实现Softmax逻辑回归识别手写数字(MNIST数据集)
- TensorFlow之CNN实现MNIST手写数字识别
- tensorflow入门-mnist手写数字识别(一)