您的位置:首页 > 其它

Tensorflow学习笔记:模型训练数据的保存和恢复的简单实例

2017-11-22 15:36 1011 查看
#! /usr/bin/env python2
# -*- coding: utf-8 -*-

import tensorflow as tf
import numpy as np
import argparse

'''
保存模型训练后参数的简单实例
'''
print('保存和恢复模型训练后参数的简单实例:')

#创建一个图
my_graph = tf.Graph()

with my_graph.as_default():
var = tf.Variable(0, name='counter') #一个变量,初始值设置为0,但是要会话执行run才会被赋值
#创建一个op, 实现var + 2
step = tf.constant(2)
newVar = tf.add(var, step)
update = tf.assign(var, newVar)
# 启动图后, 变量必须先经过`初始化` (init) op 初始化,
# 首先必须增加一个`初始化` op 到图中.
init_op = tf.initialize_all_variables()

#创建saver来保存模型数据
saver = tf.train.Saver()

#在会话中运行或测试图
def train_or_test(is_test):
#创建会话,启动图
with tf.Session(graph = my_graph) as sess:
if is_test == False: #如果是训练模型
print('Train begin...')
sess.run(init_op) #先运行初始化操作
print('var = %d' % sess.run(var)) #打印初始值

#更新var,并打印
for i in range(5):
sess.run(update)
print('[%d] var = %d' % (i, sess.run(var)))
#保存每次迭代的结果,保存的文件从val_iter-1开始而不是0,这个有点搞不明白,有知道原因的麻烦给个留言 哈哈
saver.save(sess, './model_data1/val_iter', global_step = i)

saver.save(sess, './model_data1/val_final')

else: #如果是测试模型
print('Test begin...')
for i in range(5)[1:]:
iter_data_file = './model_data1/val_iter-' + str(i)
#恢复每次迭代的结果
saver.restore(sess, iter_data_file)
print('[%d] var = %d' % (i, sess.run(var)))

model_data = tf.train.latest_checkpoint('./model_data1/')
print(model_data) # ./model_data1/val_final
saver.restore(sess, model_data)
print('read final var = %d' % sess.run(var))

#必须定义这个main入口
def main(_):
train_or_test(ARGS.test)

if __name__ == '__main__':
parser = argparse.ArgumentParser()
parser.add_argument(
'-t',
'--test',
#type = int,
default = False,
action = 'store_true', # 运行 ./model_train1.py -t或--test 则ARGS.test被置为True
help = 'train: True, test: False.'
)

#ARGS, unparsed = parser.parse_known_args()
#tf.app.run(main=main, argv=[sys.argv[0]] + unparsed)
ARGS = parser.parse_args()
print(ARGS)
tf.app.run()

'''
命令:(1) ./model_train1.py  训练模型
(2) ./model_train1.py -t[--test] 测试模型
'''

'''
保存完 model_data1目录下出现:
checkpoint      (具有最近检查点列表的协议缓冲区)
val_final       (包含变量的值)
val_final.meta  (包含图形结构)
val_iter-1
val_iter-1.meta
...
val_iter-4
val_iter-4.meta
'''
内容来自用户分享和网络整理,不保证内容的准确性,如有侵权内容,可联系管理员处理 点击这里给我发消息
标签:  tensorflow