Tensorflow小样例-分类模型(识别mnist手写数字)
2017-08-02 10:22
459 查看
这个例子是用Tensorflow构建一个简单的两层神经网络,然后用于识别mnist手写数字:
结果如下:
import tensorflow as tf from tensorflow.examples.tutorials.mnist import input_data # 导入mnist手写数字图像 mnist = input_data.read_data_sets('MNIST_data',one_hot=True) # 增加网络层的函数 def add_layer(inputs, int_size,out_size,activation_function=None): # 神经元权重初值:从正态分布中输出随机值,shape为[int_size,out_size] Weights = tf.Variable(tf.random_normal([int_size,out_size])) # 偏差初值 biases = tf.Variable(tf.zeros([1,out_size]) + 0.1) Wx_plus_b = tf.matmul(inputs, Weights) + biases # 激活函数 if activation_function is None: outputs = Wx_plus_b else: outputs = activation_function(Wx_plus_b) return outputs # 计算准确度函数 def compute_accuracy(v_xs,v_ys): # Python中有局部变量和全局变量,当局部变量名字和全局变量名字重复时,局部变量会覆盖掉全局变量。 # 如果要给全局变量在一个函数里赋值,必须使用global语句。global VarName的表达式会告诉Python, # VarName是一个全局变量,这样Python就不会在局部命名空间里寻找这个变量了。 global prediction # 执行添加输出层命令 y_pre = sess.run(prediction,feed_dict={xs:v_xs}) # 由于是用于分类手写数字,所以输出层中10个神经元只有1个有值,在预测是否准备时,只需判断输出的向量中非零的元素的位置是否相等即可 # tf.equal : 判断两个tensor是否每个元素都相等。返回一个格式为bool的tensor # tf.argmax : 找到给定的张量tensor中在指定轴axis上的最大值的位置。 correct_prediction = tf.equal(tf.argmax(y_pre,1),tf.argmax(v_ys,1)) # 求预测结果的平均值 # cast(x, dtype, name=None):将x的数据格式转化成dtype accuracy = tf.reduce_mean(tf.cast(correct_prediction,tf.float32)) result = sess.run(accuracy,feed_dict={xs:v_xs,ys:v_ys}) return result xs = tf.placeholder(tf.float32,[None, 784]) # 28*28 ys = tf.placeholder(tf.float32,[None, 10]) # 添加隐藏层,输入值是 xs,在隐藏层有 10 个神经元 prediction = add_layer(xs,784,10,activation_function=tf.nn.softmax) # 损失函数 cross_entropy = tf.reduce_mean(-tf.reduce_sum(ys*tf.log(prediction),reduction_indices=[1])) # 定义优化器,目的是的损失函数尽可能小 train_step = tf.train.GradientDescentOptimizer(0.5).minimize(cross_entropy) # 构建对话 sess = tf.Session() # 对所有变量进行初始化 sess.run(tf.initialize_all_variables()) for i in range(1000): # 按批次训练,每批100行数据 batch_xs,batch_ys = mnist.train.next_batch(100) # 开始训练 sess.run(train_step,feed_dict={xs:batch_xs,ys:batch_ys}) # 每50次输出一次准确度 if i % 50 == 0: print(compute_accuracy(mnist.test.images,mnist.test.labels))
结果如下:
相关文章推荐
- 使用tensorflow利用神经网络分类识别MNIST手写数字数据集,转自随心1993
- TensorFlow - 手写数字识别 (MNIST), 多类分类 (multiclass classification) 问题
- TensorFlow用MNIST训练的模型来识别手写数字
- Tensorflow的Helloword:使用简单Softmax Regression模型来识别Mnist手写数字
- TensorFlow之CNN实现MNIST手写数字识别
- tensorflow入门实践例子—MNIST手写数字识别
- tensorflow中mnist手写数字识别
- 【TensorFlow-windows】(三) 多层感知器进行手写数字识别(mnist)
- tensorflow实战之二:MNIST手写数字识别的优化1-代价函数优化
- 训练Tensorflow识别手写数字 mnist
- TensorFlow 深度学习框架(6)-- mnist 数字识别及不同模型效果比较
- TensorFlow实现mnist数字识别——CNN LeNet-5模型
- TensorFlow实战5:利用卷积神经网络对图像分类(初阶:MNIST手写数字)代码实现
- TensorFlow下进行MNIST手写数字识别实例,从最简单的两层到LeNet5
- 人工智能(1)用tensorflow识别MNIST手写数字数据集
- 基于tensorflow的MNIST手写数字识别--入门篇
- 基于Tensorflow的MNIST手写数字识别(一)
- Tensorflow之 CNN卷积神经网络的MNIST手写数字识别
- Tensorflow 实现 MNIST 手写数字识别
- 【TensorFlow-windows】(四) CNN(卷积神经网络)进行手写数字识别(mnist)