Tensorflow模型的保存和加载
2018-03-01 18:30
519 查看
模型的保存
模型的加载
import tensorflow as tf from tensorflow.examples.tutorials.mnist import input_data # 载入数据集 mnist = input_data.read_data_sets("MNIST_data", one_hot=True) # 每个批次的大小 batch_size = 100 # 计算一共有多少个批次 n_batch = mnist.train.num_examples // batch_size # 定义两个placeholder x = tf.placeholder(tf.float32, [None, 784]) y = tf.placeholder(tf.float32, [None, 10]) # 创建一个简单的神经网络 W = tf.Variable(tf.zeros([784, 10])) b = tf.Variable(tf.zeros([1, 10])) prediction = tf.nn.softmax(tf.matmul(x, W) + b) # 定义二次代价函数 loss = tf.reduce_mean(tf.square(y - prediction)) # 交叉熵 # loss = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(labels= y, logits= prediction)) # 定义梯度下降法 train_step = tf.train.GradientDescentOptimizer(0.2).minimize(loss) # 初始化变量 init = tf.global_variables_initializer() # 结果存储在一个布尔型列表中 correct_prediction = tf.equal(tf.argmax(y,1), tf.argmax(prediction, 1)) # 求准确率 accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32)) with tf.Session() as sess: sess.run(init) for epoch in range(21): for batch in range(n_batch): batch_xs, batch_ys = mnist.train.next_batch(batch_size) sess.run(train_step, feed_dict={x: batch_xs, y: batch_ys}) acc = sess.run(accuracy, feed_dict={x: mnist.test.images, y: mnist.test.labels}) print("Iter" + str(epoch) + ", Testing Accuracy" + str(acc)) # 保存训练好的网络模型 saver = tf.train.Saver() saver.save(sess, 'net/my_net.ckpt')
模型的加载
import tensorflow as tf from tensorflow.examples.tutorials.mnist import input_data # 载入数据集 mnist = input_data.read_data_sets("MNIST_data", one_hot=True) # 每个批次的大小 batch_size = 100 # 计算一共有多少个批次 n_batch = mnist.train.num_examples // batch_size # 定义两个placeholder x = tf.placeholder(tf.float32, [None, 784]) y = tf.placeholder(tf.float32, [None, 10]) # 创建一个简单的神经网络 W = tf.Variable(tf.zeros([784, 10])) b = tf.Variable(tf.zeros([1, 10])) prediction = tf.nn.softmax(tf.matmul(x, W) + b) # 定义二次代价函数 loss = tf.reduce_mean(tf.square(y - prediction)) # 交叉熵 # loss = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(labels= y, logits= prediction)) # 定义梯度下降法 train_step = tf.train.GradientDescentOptimizer(0.2).minimize(loss) # 初始化变量 init = tf.global_variables_initializer() # 结果存储在一个布尔型列表中 correct_prediction = tf.equal(tf.argmax(y,1), tf.argmax(prediction, 1)) # 求准确率 accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32)) with tf.Session() as sess: sess.run(init) print(sess.run(accuracy, feed_dict={x: mnist.test.images, y: mnist.test.labels})) saver = tf.train.Saver() saver.restore(sess, 'net/my_net.ckpt') print(sess.run(accuracy, feed_dict={x: mnist.test.images, y: mnist.test.labels}))
相关文章推荐
- TensorFlow 深度学习框架(7)-- 变量管理及训练模型的保存与加载
- tensorflow模型保存与加载
- TensorFlow 深度学习框架(7)-- 变量管理及训练模型的保存与加载
- 解决tensorflow模型参数保存和加载的问题
- TensorFlow 深度学习框架(7)-- 变量管理及训练模型的保存与加载
- tensorflow-模型保存和加载(一)
- tensorflow学习笔记六:保存和加载训练模型
- TensorFlow 模型保存与加载
- TensorFlow 深度学习框架(7)-- 变量管理及训练模型的保存与加载
- TensorFlow 深度学习框架(7)-- 变量管理及训练模型的保存与加载
- Tensorflow加载预训练模型和保存模型
- tensorflow中保存模型、加载模型做预测(不需要再定义网络结构)
- TensorFlow保存和加载训练模型
- TensorFlow 深度学习框架(7)-- 变量管理及训练模型的保存与加载
- TensorFlow构建卷积神经网络/模型保存与加载/正则化
- tensorflow 保存和加载模型 -2
- TensorFlow 深度学习框架(7)-- 变量管理及训练模型的保存与加载
- tensorflow初探---模型文件保存和加载
- tensorflow 保存和加载模型Saver的使用
- [Tensorflow之九]Tensorflow模型的保存与加载