MNIST手写数据集识别TensorFlow代码实现。
2018-11-25 11:03
651 查看
版权声明:李阿刁 https://blog.csdn.net/qq_33475649/article/details/84478274
以下是MNIST手写数据集识别TensorFlow实现代码,其中加入正则化过程和指数衰减的学习率设置。参考书籍:Tensorflow 实战Google深度学习框架(第2版)
[code]from tensorflow.examples.tutorials.mnist import input_data #MINIST相关常数 INPUT_NODE=784 OUTPUT_NODE=10 #配置神经网络的参数 LAYER1_NODE=500 BATCH_SIZE=100 LEARNING_RATE_BASE=0.8 LEARNING_RATE_DECAY=0.99 REGULARIZATON_RATE=0.0001 TRAINING_STEPS=30000 MOVING_AVERAGE_DECAY=0.99 def inference(input_tensor,avg_class,weights1,biases1,weights2,biases2): if avg_class==None: layer1=tf.nn.relu(tf.matmul(input_tensor,weights1)+biases1) return tf.matmul(layer1,weights2)+biases2 else: layer1=tf.nn.relu(tf.matmul(input_tensor,avg_class.average(weights1)+avg_class.average(biases1))) return tf.matmul(layer1,avg_class.average(weights2))+avg_class.average(biases2) def train(mnist): x=tf.placeholder(tf.float32,[None,INPUT_NODE],name='x-input') y_=tf.placeholder(tf.float32,[None,OUTPUT_NODE],name='y-input') weights1=tf.Variable(tf.truncated_normal([INPUT_NODE,LAYER1_NODE],stddev=0.1)) biases1=tf.Variable(tf.constant(0.1,shape=[LAYER1_NODE])) weights2=tf.Variable(tf.truncated_normal([LAYER1_NODE,OUTPUT_NODE],stddev=0.1)) biases2=tf.Variable(tf.constant(0.1,shape=[OUTPUT_NODE])) y=inference(x,None,weights1,biases1,weights2,biases2) global_step=tf.Variable(0,trainable=False) variable_averages=tf.train.ExponentialMovingAverage(MOVING_AVERAGE_DECAY,global_step) variables_averages_op=variable_averages.apply(tf.trainable_variables()) average_y=inference(x,variable_averages,weights1,biases1,weights2,biases2) cross_entropy=tf.nn.sparse_softmax_cross_entropy_with_logits(logits=y,labels=tf.argmax(y_,1)) cross_entropy_mean=tf.reduce_mean(cross_entropy) regularizer=tf.contrib.layers.l2_regularizer(REGULARIZATON_RATE) regularization=regularizer(weights1)+regularizer(weights2) loss= cross_entropy_mean+regularization learning_rate=tf.train.exponential_decay(LEARNING_RATE_BASE,global_step,mnist.train.num_examples,LEARNING_RATE_DECAY) train_step=tf.train.GradientDescentOptimizer(learning_rate).minimize(loss,global_step=global_step) with tf.control_dependencies([train_step,variables_averages_op]): train_op=tf.no_op(name='train') correct_prediction=tf.equal(tf.argmax(average_y,1),tf.argmax(y_,1)) accuracy=tf.reduce_mean(tf.cast(correct_prediction,tf.float32)) with tf.Session() as sess: tf.global_variables_initializer().run() validate_feed={x:mnist.validation.images,y_:mnist.validation.labels} test_feed={x:mnist.test.images,y_:mnist.test.labels} for i in range(TRAINING_STEPS): if i%1000==0: validate_acc=sess.run(accuracy,feed_dict=validate_feed) print("After %d training step(s),validation accuracy is %g"%(i,validate_acc)) xs,ys=mnist.train.next_batch(BATCH_SIZE) sess.run(train_op,feed_dict={x:xs,y_:ys}) test_acc=sess.run(accuracy,feed_dict=test_feed) print("After %d training step(s),test accuarcy is %g"%(i,test_acc)) def main(argv=None): mnist=input_data.read_data_sets("/tmp/data",one_hot=True) train(mnist) if __name__=='__main__': tf.app.run()
相关文章推荐
- TensorFlow代码实现(一)[MNIST手写数字识别]
- 利用tensorflow一步一步实现基于MNIST 数据集进行手写数字识别的神经网络,逻辑回归
- Deep Learning-TensorFlow (1) CNN卷积神经网络_MNIST手写数字识别代码实现
- tensorflow 学习专栏(五):在mnist数据集上使用tensorflow实现临近算法(Nearest-Neighbor)进行手写数字识别
- Deep Learning-TensorFlow (1) CNN卷积神经网络_MNIST手写数字识别代码实现详解
- 基于TensorFlow的mnist数据集的最近邻算法实现代码
- Python神经网络代码识别手写字的实现流程(一):加载mnist数据
- 基于tensorflow和lenet-5模型实现mnist手写数字识别
- 用tensorflow实现MNIST(手写数字识别)
- TensorFlow手写数字MNIST识别,两层卷积神经网路(代码及代码注释)最后的准确率0.99
- tensorflow实现MNIST数据集识别---进一步理解
- tensorflow 学习笔记7 普通神经网络实现mnist手写识别
- tensorflow——用RNN实现MNIST手写数字识别
- tensorflow 学习笔记12 循环神经网络RNN LSTM结构实现MNIST手写识别
- Tensorflow 实现MNIST手写数字体识别
- 基于TensorFlow1.4.0的FNN全连接网络识别MNIST手写数据集
- TensorFlow之CNN实现MNIST手写数字识别
- Tensorflow 实现 MNIST 手写数字识别
- tensorflow 第一个程序MNIST手写数字识别(Softmax Regression实现)
- 深度学习-传统神经网络使用TensorFlow框架实现MNIST手写数字识别