[DL]3.基于CNN的手写数字识别
2017-08-28 12:06
274 查看
CNN的原理这里就不介绍了,如果想要了解详细的原理可以参考链接
本文注重的是CNN在TensorFlow中是如何实现的。CNN可以用图片表示为(仅供想象用,但不是本文使用的模型的图片表示):
下面结合代码具体解释。
最终经过大约半个小时的训练得到了99.1%的准确率。
本文注重的是CNN在TensorFlow中是如何实现的。CNN可以用图片表示为(仅供想象用,但不是本文使用的模型的图片表示):
下面结合代码具体解释。
from tensorflow.examples.tutorials.mnist import input_data print('数据加载...') mnist=input_data.read_data_sets('./data/mnist',one_hot=True) # 可以看到返回的是Datasets类型,包含了训练集、验证集、测试集 #return base.Datasets(train=train, validation=validation, test=test) print('图片表示示例:') print(mnist[0].images[0]) print('标签表示示例:') print(mnist[0].labels[0]) img_count_train=len(mnist[0].images) img_array_train=len(mnist[0].images[0]) img_label_train=len(mnist[0].labels[0]) print('训练集有%s张图片,每张图片表示为%s维数组,标签以one-hot方式编码为%s维数组。'%(img_count_train,img_array_train,img_label_train)) img_count_validation=len(mnist[1].images) img_array_validation=len(mnist[1].images[0]) img_label_validation=len(mnist[1].labels[0]) print('验证集有%s张图片,每张图片表示为%s维数组,标签以one-hot方式编码为%s维数组。'%(img_count_validation,img_array_validation,img_label_validation)) img_count_test=len(mnist[2].images) img_array_test=len(mnist[2].images[0]) img_label_test=len(mnist[2].labels[0]) print('测试集有%s张图片,每张图片表示为%s维数组,标签以one-hot方式编码为%s维数组。'%(img_count_test,img_array_test,img_label_test)) print('数据加载done...') print('-----------------------------------------------------------------------') #模型构建 #784*1->28*28*32->14*14*64->7*7*64*1024 #784表示的是一维的张量,代表的是一张图片中的像素组成,这里输入是784维,输出是28*28维 #28*28*32表示这一层的输入是28×28维,并且32个特征每个特征都做一次卷积。本层的输入是28*28*1维,输出是32个14*14维。 #14*14*64表示这一层的输入是14×14维,并且64个特征每个特征都做一次卷积。本层的输入是14*14*32维,输出是64个7*7维。 #7*7*64*1024表示这一层的输入是7*7*64维,并且与1024个神经元连接。本层的输入是7*7*64维,输出是1024维。 import tensorflow as tf def conv2d(x,W): return tf.nn.conv2d(x,W,strides=[1,1,1,1],padding='SAME')#[1,1,1,1]=[batch, channels, height, width]. def max_pool_2x2(x): return tf.nn.max_pool(x,[1,2,2,1],[1,2,2,1],padding='SAME')#步长=2,窗口大小=2 #共享参数 def weigth_variables(shape): initial=tf.truncated_normal(shape,stddev=0.1) return tf.Variable(initial) #共享偏置值 def bias_variables(shape): initial=tf.constant(0.1,shape=shape) return tf.Variable(initial) y_ = tf.placeholder(tf.float32, shape=[None, 10]) #输出 x = tf.placeholder(tf.float32, shape=[None, 784]) #输入 x_image=tf.reshape(x,[-1,28,28,1]) #reshape之后的图片 [batch, in_height, in_width, in_channels] #第一层 weight1=weigth_variables([5,5,1,32]) #window=5*5 input=1 features=32 bias1=bias_variables([32]) out1=max_pool_2x2(tf.nn.relu(conv2d(x_image,weight1)+bias1)) #input:28*28*1-> output:32*14*14 #第二层 weight2=weigth_variables([5,5,32,64])#input:32*14*14->output:64*7*7 bias2=bias_variables([64]) out2=max_pool_2x2(tf.nn.relu(conv2d(out1,weight2)+bias2)) #densly layer weight3=weigth_variables([7*7*64,1024])#input:64*7*7->output:1024 bias3=bias_variables([1024]) input=tf.reshape(out2,[-1,7*7*64]) output=tf.nn.relu(tf.matmul(input,weight3)+bias3) #dropout keep_prob=tf.placeholder(tf.float32) tf.nn.dropout(output,keep_prob) #output weight4=weigth_variables([1024,10])#input:1024->output:10 bias4=bias_variables([10]) y=tf.matmul(output,weight4)+bias4 #测试 cross_entropy = tf.reduce_mean( tf.nn.softmax_cross_entropy_with_logits(labels=y_, logits=y)) train_step = tf.train.AdamOptimizer(1e-4).minimize(cross_entropy) correct_prediction = tf.equal(tf.argmax(y, 1), tf.argmax(y_, 1)) accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32)) with tf.Session() as sess: sess.run(tf.global_variables_initializer()) for i in range(20000): batch = mnist.train.next_batch(50) if i % 100 == 0: train_accuracy = accuracy.eval(feed_dict={ x: batch[0], y_: batch[1], keep_prob: 1.0}) print('step %d, training accuracy %g' % (i, train_accuracy)) train_step.run(feed_dict={x: batch[0], y_: batch[1], keep_prob: 0.5}) print('test accuracy %g' % accuracy.eval(feed_dict={ x: mnist.test.images, y_: mnist.test.labels, keep_prob: 1.0}))
最终经过大约半个小时的训练得到了99.1%的准确率。
相关文章推荐
- python tensorflow基于cnn实现手写数字识别
- [052]TensorFlow Layers指南:基于CNN的手写数字识别
- [Python]基于CNN的MNIST手写数字识别
- Tensorflow - Tutorial (4) :基于CNN的手写数字识别
- python tensorflow 基于cnn实现手写数字识别
- TensorFlow学习-基于CNN实现手写数字识别
- 基于Tensorflow的MNIST手写数字识别(一)
- 卷积神经网络CNN 手写数字识别
- 基于Keras搭建用于MNIST手写数字识别的CNN
- tensorflow进行MNIST手写数字识别-CNN
- 利用tensorflow一步一步实现基于MNIST 数据集进行手写数字识别的神经网络,逻辑回归
- 基于opencv的手写数字识别(MFC,HOG,SVM)
- 基于Keras搭建用于MNIST手写数字识别的CNN
- 基于tensorflow的MNIST手写数字识别(二)--入门篇
- Tensorflow手写数字识别之简单神经网络分类与CNN分类效果对比
- 基于Tensorflow的MNIST手写数字识别(三)
- TensorFlow学习笔记(3)----CNN识别MNIST手写数字
- CNN实现MNIST手写数字识别
- TensorFlow笔记(三)--CNN识别手写数字
- Android+TensorFlow+CNN+MNIST 手写数字识别实现