TensorFlow 训练 MNIST 数据
2017-08-16 11:23
302 查看
参考:
http://blog.csdn.net/willduan1/article/details/52024254
这个是用一个简单的 Softmax 回归来进行训练的代码
有一些注意的地方就是,当在一定程度上增加训练次数的时候,正确率会有一点的提高。还有一点就是取的batch大的时候,进行BP传播的时候收敛速度会快一些,但是相应的训练的时间消耗的就会增加。
http://blog.csdn.net/willduan1/article/details/52024254
这个是用一个简单的 Softmax 回归来进行训练的代码
from tensorflow.examples.tutorials.mnist import input_data import tensorflow as tf #MNIST数据输入 mnist = input_data.read_data_sets("../../datasets/MNIST_data/", one_hot=True) x = tf.placeholder(tf.float32, [None, 784]) #图像输入向量 W = tf.Variable(tf.zeros([784, 10])) #权重,初始化值为全零 b = tf.Variable(tf.zeros([10])) #偏置,初始化值为全零 #进行模型计算,y是预测,y_ 是实际 y = tf.nn.softmax(tf.matmul(x, W) + b) y_ = tf.placeholder("float", [None, 10]) #计算交叉熵 cross_entropy = -tf.reduce_sum(y_ * tf.log(y)) #接下来使用BP算法来进行微调,以0.01的学习速率 train_step = tf.train.GradientDescentOptimizer(0.01).minimize(cross_entropy) #上面设置好了模型,添加初始化创建变量的操作 init = tf.global_variables_initializer() #启动创建的模型,并初始化变量 sess = tf.Session() sess.run(init) #开始训练模型,循环训练1000次 for i in range(1000): #随机抓取训练数据中的100个批处理数据点 batch_xs, batch_ys = mnist.train.next_batch(100) sess.run(train_step, feed_dict={x: batch_xs, y_: batch_ys}) ''''' 进行模型评估 ''' #判断预测标签和实际标签是否匹配 correct_prediction = tf.equal(tf.argmax(y, 1), tf.argmax(y_, 1)) accuracy = tf.reduce_mean(tf.cast(correct_prediction, "float")) #计算所学习到的模型在测试数据集上面的正确率 print(sess.run(accuracy, feed_dict={x: mnist.test.images, y_: mnist.test.labels}))
输出:0.9169
有一些注意的地方就是,当在一定程度上增加训练次数的时候,正确率会有一点的提高。还有一点就是取的batch大的时候,进行BP传播的时候收敛速度会快一些,但是相应的训练的时间消耗的就会增加。
相关文章推荐
- [置顶] TensorFlow 训练 MNIST 数据(二)
- tensorflow 训练mnist数据
- [置顶] TensorFlow 入门之训练 MNIST 数据
- Tensorflow深度学习入门——采用卷积和池化优化训练MNIST数据——代码+注释
- Tensorflow中mnist数据使用CNN训练
- tensorflow实现AlexNet训练mnist数据
- TensorFlow个人学习(训练 MNIST 数据 )
- Tensorflow训练mnist数据(完整版)
- 使用TensorFlow训练神经网络识别MNIST数据代码
- TensorFlow——训练自己的数据(四)模型测试
- tensorflow之MNIST手写字符集训练可视化
- tensorflow中mnist 使用cnn模型训练的输出层数为7x7的原因
- mxnet利用下载好的mnist数据训练cnn
- TensorFlow学习笔记-组合训练数据
- 使用Tensorflow训练自己的数据
- tensorflow 分布式 数据并行 异步训练 between-graph 自己写的实例 CNN
- MNIST数据集的卷积神经网络训练代码具体实现示例--Tensorflow 框架
- Tensorflow训练Kitti道路分割数据
- Tensorflow教程-MNIST 数据下载
- TensorFlow安装与入门: 使用CNN训练MNIST