TensorFlow - 手写数字识别 (模型的存储与加载)
2018-02-11 20:35
489 查看
TensorFlow - 手写数字识别 (模型的存储与加载)
flyfish
目的:解决训练过程的中断,再次训练从上次训练之后结果接着训练
而不是从头开始训练
环境Win10 Python3.6
Start from: 0
step 0, training accuracy 0.08
step 10, training accuracy 0.12
step 20, training accuracy 0.22
step 30, training accuracy 0.2
step 40, training accuracy 0.24
Process finished with exit code 1
中断再次启动之后又接着上次开始训练
Start from: 42
step 50, training accuracy 0.52
模型存储目录
E:\MyWork\venv\ckpt_dir
flyfish
目的:解决训练过程的中断,再次训练从上次训练之后结果接着训练
而不是从头开始训练
环境Win10 Python3.6
import tensorflow as tf import tensorflow.examples.tutorials.mnist.input_data as input_data import os mnist = input_data.read_data_sets('MNIST_data', one_hot=True) x = tf.placeholder("float", shape=[None, 784]) y_ = tf.placeholder("float", shape=[None, 10]) W = tf.Variable(tf.zeros([784,10])) b = tf.Variable(tf.zeros([10])) def weight_variable(shape): initial = tf.truncated_normal(shape, stddev=0.1) return tf.Variable(initial) def bias_variable(shape): initial = tf.constant(0.1, shape=shape) return tf.Variable(initial) #权重初始化 def conv2d(x, W): return tf.nn.conv2d(x, W, strides=[1, 1, 1, 1], padding='SAME') def max_pool_2x2(x): return tf.nn.max_pool(x, ksize=[1, 2, 2, 1], strides=[1, 2, 2, 1], padding='SAME') #第一层卷积 W_conv1 = weight_variable([5, 5, 1, 32]) b_conv1 = bias_variable([32]) x_image = tf.reshape(x, [-1,28,28,1]) h_conv1 = tf.nn.relu(conv2d(x_image, W_conv1) + b_conv1) h_pool1 = max_pool_2x2(h_conv1) #d第二层卷积 W_conv2 = weight_variable([5, 5, 32, 64]) b_conv2 = bias_variable([64]) h_conv2 = tf.nn.relu(conv2d(h_pool1, W_conv2) + b_conv2) h_pool2 = max_pool_2x2(h_conv2) #全连接层 W_fc1 = weight_variable([7 * 7 * 64, 1024]) b_fc1 = bias_variable([1024]) h_pool2_flat = tf.reshape(h_pool2, [-1, 7*7*64]) h_fc1 = tf.nn.relu(tf.matmul(h_pool2_flat, W_fc1) + b_fc1) keep_prob = tf.placeholder("float") h_fc1_drop = tf.nn.dropout(h_fc1, keep_prob) #输出层 W_fc2 = weight_variable([1024, 10]) b_fc2 = bias_variable([10]) #训练和评估模型 y_conv=tf.nn.softmax(tf.matmul(h_fc1_drop, W_fc2) + b_fc2) cross_entropy = -tf.reduce_sum(y_*tf.log(y_conv)) train_step = tf.train.AdagradOptimizer(1e-4).minimize(cross_entropy) correct_prediction = tf.equal(tf.argmax(y_conv,1), tf.argmax(y_,1)) accuracy = tf.reduce_mean(tf.cast(correct_prediction, "float")) ckpt_dir = "./ckpt_dir" if not os.path.exists(ckpt_dir): os.makedirs(ckpt_dir) #标志变量不参与到训练中 global_step = tf.Variable(0, name='global_step', trainable=False) saver = tf.train.Saver() with tf.Session() as sess: ckpt = tf.train.get_checkpoint_state(ckpt_dir) if ckpt and ckpt.model_checkpoint_path: print(ckpt.model_checkpoint_path) saver.restore(sess, ckpt.model_checkpoint_path) # restore all variables else: tf.global_variables_initializer().run() start = global_step.eval() # get last global_step print("Start from:", start) for i in range(start, 200):#这里原来是20000 接着从上次start的地方训练 batch = mnist.train.next_batch(50) if i%10 == 0: train_accuracy = accuracy.eval(session=sess,feed_dict={ x:batch[0], y_: batch[1], keep_prob: 1.0}) print ("step %d, training accuracy %g"%(i, train_accuracy)) train_step.run(session=sess,feed_dict={x: batch[0], y_: batch[1], keep_prob: 0.5}) global_step.assign(i).eval() #i更新global_step. saver.save(sess, ckpt_dir + "/model.ckpt", global_step=global_step) print ("test accuracy %g"%accuracy.eval(session=sess,feed_dict={ x: mnist.test.images, y_: mnist.test.labels, keep_prob: 1.0}))
Start from: 0
step 0, training accuracy 0.08
step 10, training accuracy 0.12
step 20, training accuracy 0.22
step 30, training accuracy 0.2
step 40, training accuracy 0.24
Process finished with exit code 1
中断再次启动之后又接着上次开始训练
Start from: 42
step 50, training accuracy 0.52
模型存储目录
E:\MyWork\venv\ckpt_dir
相关文章推荐
- TensorFlow - 手写数字识别 (模型训练完成后的使用)
- Tensorflow小样例-分类模型(识别mnist手写数字)
- 使用tensorflow基于lenet-5模型识别手写数字
- TensorFlow用MNIST训练的模型来识别手写数字
- Tensorflow的Helloword:使用简单Softmax Regression模型来识别Mnist手写数字
- TensorFlow 深度学习框架(6)-- mnist 数字识别及不同模型效果比较
- 学习笔记TF024:TensorFlow实现Softmax Regression(回归)识别手写数字
- TensorFlow 实现 Softmax Regression 识别手写数字
- 5.1 Tensorflow:图与模型的加载与存储
- tensorFlow识别手写数字
- 用TensorFlow构造CNN进行手写数字识别
- tensorflow 手写数字识别
- TensorFlow学习-基于CNN实现手写数字识别
- 机器学习三(tensorflow 训练识别手写数字)
- 深度学习-传统神经网络使用TensorFlow框架实现MNIST手写数字识别
- tensorflow进行MNIST手写数字识别-CNN
- tensorflow识别手写数字
- TensorFlow实战(一)手写数字识别
- 机器学习Tensorflow基于MNIST数据集识别自己的手写数字(读取和测试自己的模型)
- 学习笔记CB009:人工神经网络模型、手写数字识别、多层卷积网络、word2vec