TensorFlow-mnist
2017-08-30 16:34
253 查看
训练代码:
测试代码:
测试结果:
# Train a two-conv-layer CNN on MNIST (TensorFlow 1.x graph API) and save a
# checkpoint. Graph tensor names 'input', 'label', 'keep_prob', 'output' and
# 'accuracy' are kept stable so the saved graph can be restored by name after
# tf.train.import_meta_graph.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data

flags = tf.app.flags
FLAGS = flags.FLAGS
flags.DEFINE_string('data_dir', '/tmp/data/', 'Directory for storing data')
# Default preserves the original hard-coded save location.
flags.DEFINE_string('model_path', 'E:/dnn/model', 'Checkpoint path prefix for the trained model')
print(FLAGS.data_dir)
mnist = input_data.read_data_sets(FLAGS.data_dir, one_hot=True)

# Placeholders: flattened 784-pixel images, one-hot labels over 10 classes,
# and the dropout keep probability (0.5 while training, 1.0 for evaluation).
# The Python name avoids shadowing the builtin `input`; the graph name stays 'input'.
x_input = tf.placeholder(tf.float32, [None, 784], name='input')
label = tf.placeholder(tf.float32, [None, 10], name='label')
keep_prob = tf.placeholder(tf.float32, name='keep_prob')
image = tf.reshape(x_input, [-1, 28, 28, 1])

# Conv block 1: 5x5 conv with 32 filters + ELU, then 2x2 max-pool (28x28 -> 14x14).
conv1_W = tf.Variable(tf.truncated_normal([5, 5, 1, 32], stddev=0.1))
conv1_b = tf.Variable(tf.constant(0.1, shape=[32]))
layer1 = tf.nn.elu(tf.nn.conv2d(image, conv1_W, strides=[1, 1, 1, 1], padding='SAME') + conv1_b)
layer2 = tf.nn.max_pool(layer1, ksize=[1, 2, 2, 1], strides=[1, 2, 2, 1], padding='SAME')

# Conv block 2: 5x5 conv with 64 filters + ELU, then 2x2 max-pool (14x14 -> 7x7).
conv2_W = tf.Variable(tf.truncated_normal([5, 5, 32, 64], stddev=0.1))
conv2_b = tf.Variable(tf.constant(0.1, shape=[64]))
layer3 = tf.nn.elu(tf.nn.conv2d(layer2, conv2_W, strides=[1, 1, 1, 1], padding='SAME') + conv2_b)
layer4 = tf.nn.max_pool(layer3, ksize=[1, 2, 2, 1], strides=[1, 2, 2, 1], padding='SAME')

# Fully connected layer (7*7*64 -> 1024) with ELU and dropout.
layer5 = tf.reshape(layer4, [-1, 7 * 7 * 64])
fc1_W = tf.Variable(tf.truncated_normal([7 * 7 * 64, 1024], stddev=0.1))
fc1_b = tf.Variable(tf.constant(0.1, shape=[1024]))
layer6 = tf.nn.elu(tf.matmul(layer5, fc1_W) + fc1_b)
layer7 = tf.nn.dropout(layer6, keep_prob)

# Output layer: 10 logits; 'output' exposes the softmax probabilities.
fc2_W = tf.Variable(tf.truncated_normal([1024, 10], stddev=0.1))
fc2_b = tf.Variable(tf.constant(0.1, shape=[10]))
logits = tf.matmul(layer7, fc2_W) + fc2_b
output = tf.nn.softmax(logits, name='output')

# Compute the loss from logits with the fused op: tf.log(softmax) becomes
# -inf (and the loss NaN) as soon as any probability underflows to 0.
cross_entropy = tf.reduce_mean(
    tf.nn.softmax_cross_entropy_with_logits(labels=label, logits=logits))
train_step = tf.train.AdamOptimizer(1e-4).minimize(cross_entropy)

correct_prediction = tf.equal(tf.argmax(output, 1), tf.argmax(label, 1))
accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32), name='accuracy')

sess = tf.InteractiveSession()
sess.run(tf.global_variables_initializer())

for i in range(20000):
    batch = mnist.train.next_batch(50)
    train_step.run(feed_dict={x_input: batch[0], label: batch[1], keep_prob: 0.5})
    if i % 100 == 0:
        # Accuracy on the batch just trained on, with dropout disabled.
        train_accuracy = accuracy.eval(feed_dict={x_input: batch[0], label: batch[1], keep_prob: 1.0})
        print("%d:training accuracy %g" % (i, train_accuracy))

saver = tf.train.Saver()
save_path = saver.save(sess, FLAGS.model_path)
sess.close()
测试代码:
# Classify a single digit image with the model checkpoint saved under E:/dnn,
# restoring the graph by tensor name and printing the class probabilities
# followed by the predicted digit.
from __future__ import division

import numpy as np
import tensorflow as tf
from PIL import Image

IMAGE_PATH = 'E:/dnn/test.bmp'
MODEL_DIR = 'E:/dnn'

# Load as 8-bit grayscale; the context manager closes the underlying file
# (convert() loads the pixel data before the file is closed).
with Image.open(IMAGE_PATH) as src:
    img = src.convert('L')
if img.size[0] != 28 or img.size[1] != 28:
    img = img.resize((28, 28))

# Scale to [0, 1] and invert, so dark pixels map to values near 1.0.
# NOTE(review): assumes the input is a dark digit on a light background,
# presumably to match MNIST's light-on-dark digits — confirm for your images.
pixels = 1.0 - np.asarray(img, dtype=np.float64) / 255.0
image = pixels.reshape((1, 784))  # row-major, same order as the training input

saver = tf.train.import_meta_graph(MODEL_DIR + '/model.meta')
graph = tf.get_default_graph()
x_input = graph.get_tensor_by_name('input:0')
keep_prob = graph.get_tensor_by_name('keep_prob:0')
output = graph.get_tensor_by_name('output:0')

# latest_checkpoint returns None when no checkpoint exists; fail with a clear message.
checkpoint = tf.train.latest_checkpoint(MODEL_DIR)
if checkpoint is None:
    raise IOError('No checkpoint found in %s' % MODEL_DIR)

with tf.Session() as sess:
    saver.restore(sess, checkpoint)
    # 'output' does not depend on the 'label' placeholder, so it is not fed.
    test = sess.run(output, feed_dict={x_input: image, keep_prob: 1.0})

print(test)
# np.argmax returns the first index of the maximum, like a strict '>' scan.
ans = int(np.argmax(test[0]))
print(ans)
测试结果:
![测试结果截图](https://images2017.cnblogs.com/blog/293751/201708/293751-20170830163326296-1167551388.png)
相关文章推荐
- TensorFlow教程06:MNIST的CNN实现——源码和运行结果
- TensorFlow利用普通神经网络识别MNIST以及tensorboard可视化
- tensorflow用Softmax Regression识别MNIST手写数字识别
- tensorflow学习笔记之使用tensorflow进行MNIST分类(1)
- TensorFlow官方教程学习笔记之3-用于机器学习专家学习的MNIST数据集(MNIST For ML Beginners)
- [Tensorflow] MNIST数字识别问题
- TensorFlow 教程 - 深入MNIST完整代码
- 【TensorFlow】MNIST(使用CNN)
- TensorFlow学习笔记(十四)TensorFLow 用mnist数据做classification
- TensorFlow 卷积神经网络之MNIST 手写数字识别
- 【TensorFlow】官方MNIST数据集神经网络实例详解(六)
- TensorFlow技术解析与实战 9 TensorFlow在MNIST中的应用
- tensorflow-002-MNIST
- 使用tensorFlow完成对MNIST数据集的训练
- 使用TensorFlow训练MNIST数据集,SystemExit异常的解决方案
- tensorflow使用RNN分析mnist手写体数字数据集
- Tensorflow中的mnist例子
- TensorFlow框架之使用卷积神经网络优化MNIST
- tensorflow实战之四:MNIST手写数字识别的优化3-过拟合