tensorflow学习——tf.train.Supervisor()与tf.train.saver()
2017-09-12 17:53
555 查看
1、tf.train.Supervisor()
从上面代码可以看出,Supervisor帮助我们处理一些事情
(1)自动去checkpoint加载数据或初始化数据
(2)自身有一个Saver,可以用来保存checkpoint
(3)有一个summary_computed用来保存Summary
所以,我们就不需要:
(1)手动初始化或从checkpoint中加载数据
(2)不需要创建Saver,使用sv内部的就可以
(3)不需要创建summary writer
2、tf.train.Saver()
import tensorflow as tf import numpy as np import os log_path = 'ckptdir/' log_name = 'liner.ckpt' x_data = np.random.rand(100).astype(np.float32) y_data = x_data*0.1 + 0.3 w = tf.Variable(tf.random_normal([1])) b = tf.Variable(tf.zeros([1])) y = w*x_data + b loss = tf.reduce_mean(tf.square(y-y_data)) train = tf.train.AdamOptimizer(0.5).minimize(loss) tf.summary.scalar('loss', loss) saver = tf.train.Saver() init = tf.global_variables_initializer() merged = tf.summary.merge_all() sv = tf.train.Supervisor(logdir=log_path, init_op=init) # logdir用来保存checkpoint和summary saver = sv.saver # 创建saver with sv.managed_session() as sess: # 会自动去logdir中去找checkpoint,如果没有的话,自动执行初始化 # sess.run(init) # if len(os.listdir(log_path)) != 0: # saver.restore(sess, os.path.join(log_path, log_name)) for step in range(201): sess.run(train) if step%50 == 0: print(step, sess.run(w), sess.run(b)) merged_summary = sess.run(merged) sv.summary_computed(sess, merged_summary,global_step=step) saver.save(sess, os.path.join(log_path, 'liner.ckpt'))
从上面代码可以看出,Supervisor帮助我们处理一些事情
(1)自动去checkpoint加载数据或初始化数据
(2)自身有一个Saver,可以用来保存checkpoint
(3)有一个summary_computed用来保存Summary
所以,我们就不需要:
(1)手动初始化或从checkpoint中加载数据
(2)不需要创建Saver,使用sv内部的就可以
(3)不需要创建summary writer
2、tf.train.Saver()
import tensorflow as tf import numpy as np import os log_path = 'ckptdir' log_name = 'liner.ckpt' x_data = np.random.rand(100).astype(np.float32) y_data = x_data*0.1 + 0.3 w = tf.Variable(tf.random_normal([1])) b = tf.Variable(tf.zeros([1])) y = w*x_data + b loss = tf.reduce_mean(tf.square(y-y_data)) train = tf.train.AdamOptimizer(0.5).minimize(loss) tf.summary.scalar('loss', loss) saver = tf.train.Saver() init = tf.global_variables_initializer() merged = tf.summary.merge_all() with tf.Session() as sess: sess.run(init) print("loading model from checkpoint") checkpoint = tf.train.latest_checkpoint(os.path.join(log_path, log_name)) restore_saver.restore(sess, checkpoint) #if len(os.listdir(log_path)) != 0: # saver.restore(sess, os.path.join(log_path, log_name)) for step in range(201): sess.run(train) if step%50 ==0: print(step, sess.run(w), sess.run(b)) summary_writer = tf.summary.FileWriter(log_path, sess.graph) summary_all = sess.run(merged) summary_writer.add_summary(summary_all) summary_writer.close() saver.save(sess, os.path.join(log_path, 'liner.ckpt'))
相关文章推荐
- tensorflow学习day2简单监督学习模型及用tf.train.Saver实现检查点恢复
- 【TensorFlow】模型持久化tf.train.Saver—下(九)
- tensorflow 1.0之tf.train.Saver 文档翻译
- tensorflow模型持久化之tf.train.saver
- tensorflow学习——tf.floor与tf.train.batch
- tensorflow中关于模型存储和恢复(tf.train.Saver())的问题
- TensorFlow入门使用 tf.train.Saver()保存模型
- TensorFlow入门(九)使用 tf.train.Saver()保存模型
- tensorflow 下的滑动平均模型 —— tf.train.ExponentialMovingAverage
- 【深度学习】TensorFlow的TFRecord存储
- [TensorFlow 学习笔记-04]卷积函数之tf.nn.conv2d
- TensorFlow学习笔记----TF生成数据的方法
- Tensorflow学习之tfrecords_reader
- 学习笔记TF049:TensorFlow 模型存储加载、队列线程、加载数据、自定义操作
- tf.train.Saver
- tensorboard学习——tf.train.SummaryWriter无此属性
- 3. Tensorflow学习笔记之tf.placeholder函数
- tensorflow学习——tfreader格式,队列读取数据tf.train.shuffle_batch()
- 深度学习笔记——深度学习框架TensorFlow(四)[高级API tf.contrib.learn]
- 学习TensorFlow之tf.placeholder()