您的位置:首页 > 其它

tensorflow训练mnist数据集-识别手写数字

2018-04-04 21:02 411 查看

1、加载mnist数据集

    TensorFlow提供了mnist加载的封装,在python中直接运行以下代码即可完成mnist数据加载。


    运行成功会出现如下现实:


    下载好的mnist数据集在C:\Users\%你的主机名%文件夹下:


    然后我们来查看一下这个数据集的情况,print(mnisr.trian.images.shape,mnist.train.labels.shape),该语句输出结果是:(55000,784) (55000,10)表示了训练集images数据一共有55000个样本,每一个样本长度是784,标签一共是10个数字。print(mnist.test.images.shape,minst.test.labels.shape),该语句输出的结果是(10000,784) (10000,10)表示测试集images数据一共10000个样本,每个样本长度是784,标签一共10个数字;print(mnist.validation.images.shape,mnist.validation.labels.shape),该语句输出的结果是(5000,784) (5000,10)表示image验证集有5000个样本,每个样本长度是784,标签一共10个数字。三种类型数据集的作用:在训练集上训练模型,利用验证集检验模型效果并决定何时完成训练,最后在测试机评价模型的效果。


>>> import tensorflow as tf  

#载入TensorFlow库

>>> sess = tf.InteractiveSession()  

#创建一个新的InteractiveSession,使用这个命令会将session注册为默认的session,之后的运算也默认泡在这个session里

#面,不同的session之间的数据和运算是相互的里的。

>>> x = tf.placeholder(tf.float32,[None,784])

#创建一个Placeholder,这个是输入数据的地方;Placeholder的第一个参数是数据类型,第二个参数【none,784】代表

#Tensor的shape,也是数据的尺寸,这里none代表不限条数的输入,784表示每条输入时一个784 维的向量。

>>> w = tf.Variable(tf.zeros([784,10]))

>>> b = tf.Variable(tf.zeros([10]))

#给Softmax Regression模型中的weights和biases创建Variable对象(用来存储模型参数)我们把weights和biases全部初始化

#为0

>>> y = tf.nn.softmax(tf.matmul(x,w) + b)

#实现Softmax Regression模型算法

>>> y_ = tf.placeholder(tf.float32, [None,10])

#定义一个Placeholder,输入真是的label,用来计算损失函数

>>> cross_entropy = tf.reduce_mean(-tf.reduce_sum(y_ * tf.log(y),reduction_indices=[1]))

#定义了一个损失函数cross-entropy

>>> train_step = tf.train.GradientDescentOptimizer(0.5).minimize(cross_entropy)

#这一步是训练,设置学习速率是0.5

>>> tf.global_variables_initializer().run()

#下一步使用TensorFlow的全局参数初始化tf.global_variables_initializer,然后调用run()方法。

>>> for i in range(1000):
...     batch_xs,batch_ys = mnist.train.next_batch(100)
...     train_step.run({x:batch_xs, y_:batch_ys})

...

#迭代地进行训练操作train_step。

>>> correct_prediction = tf.equal(tf.argmax(y,1),tf.argmax(y_,1))

#完成训练后,对模型的准确率进行验证。

>>> accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))

#统计全部预测的accuracy,这里需要先用tf.cast将之前输出结果bool转换为float32,再求平均值

>>> print(accuracy.eval({x:mnist.test.images,y_:mnist.test.labels}))

0.9192

#最后输出结果,准确率大概在91%~92%之间。

>>>
阅读更多
内容来自用户分享和网络整理,不保证内容的准确性,如有侵权内容,可联系管理员处理 点击这里给我发消息
标签: 
相关文章推荐