您的位置:首页 > 其它

新手上手Tensorflow之手写数字识别应用(2)

2017-12-02 13:45 676 查看
本系列为应用TensorFlow实现手写数字识别应用的全过程的代码实现及细节讨论。按照实现流程,分为如下几部分:

1. 模型训练并保存模型

2. 通过鼠标输入数字并保存

2. 图像预处理

4. 读入模型对输入的图片进行识别

本文重点讨论模型的保存以及读入问题。

关于TensorFlow模型训练的部分,算法实现部分的论文、博客以及源码很多很多,相信大家也看了很多了,这里就不过多讨论。重点是,我们如何把我们训练的模型保存以及如何读入的问题。

训练完模型后,我们会得到模型参数的训练结果。如果我们想之后分享这个结果或者用来进行测试,就要保存这个结果了。TensorFlow提供了Saver类来保存和恢复(save/restore)模型参数。

1. Saver类保存和恢复参数的用法

首先来看一下官方给出的demo:

#Saving variables
# Create some variables.
import os
os.environ['TF_CPP_MIN_LOG_LEVEL']='2' #屏蔽乱七八糟的输出信息--强迫症患者。。。

import shutil
checkpoints_dir = './checkpoint1940/'
if os.path.exists(checkpoints_dir):
shutil.rmtree(checkpoints_dir)
os.makedirs(checkpoints_dir)
checkpoint_prefix = os.path.join(checkpoints_dir, 'model.ckpt')
import tensorflow as tf
v1 = tf.get_variable("v1", shape=[3], initializer = tf.zeros_initializer)
v2 = tf.get_variable("v2", shape=[5], initializer = tf.zeros_initializer)

inc_v1 = v1.assign(v1+1)
dec_v2 = v2.assign(v2-1)

# Add an op to initialize the variables.
init_op = tf.global_variables_initializer()

# Add ops to save and restore all the variables.
saver = tf.train.Saver()

# Later, launch the model, initialize the variables, do some work, and save the
# variables to disk.
with tf.Session() as sess:
sess.run(init_op)
# Do some work with the model.
inc_v1.op.run()
dec_v2.op.run()
# Save the variables to disk.
save_path = saver.save(sess, checkpoint_prefix)
print("Model saved in file: %s" % save_path)


运行结果:



import os
os.environ['TF_CPP_MIN_LOG_LEVEL']='2' #屏蔽乱七八糟的输出信息--强迫症患者。。。

import shutil
checkpoints_dir = './checkpoint1940/'
#if os.path.exists(checkpoints_dir):
#    shutil.rmtree(checkpoints_dir)
#os.makedirs(checkpoints_dir)
checkpoint_prefix = os.path.join(checkpoints_dir, 'model.ckpt')
import tensorflow as tf
tf.reset_default_graph() #Clears the default graph stack and resets the global default graph.

# Create some variables.
v1 = tf.get_variable("v1", shape=[3])
v2 = tf.get_variable("v2", shape=[5])
#Note that when you restore variables from a file you do not have to initialize them beforehand
# Add ops to save and restore all the variables.
saver = tf.train.Saver()

# Later, launch the model, use the saver to restore variables from disk, and
# do some work with the model.
with tf.Session() as sess:
# Restore variables from disk.
saver.restore(sess, checkpoint_prefix)
print("Model restored.")
# Check the values of the variables
print("v1 : %s" % v1.eval())
print("v2 : %s" % v2.eval())


运行结果:



2. save/restore过程的技术细节

(1)checkpoint 文件

TensorFlow的Saver类是通过操作checkpoint文件来实现对变量(Variable)的存储和恢复。checkpoint文件是二进制的文件,存放着按照固定格式存储的“变量名-Tensor值”map对。一般来说,checkpoint文件有四种:



其中,checkpoint文件可以直接用记事本打开,里面存放的是最新模型的path和所有模型的path;

.meta stores the graph structure, .data stores the values of each variable in the graph, .index identifies the checkpiont. So in the example above: import_meta_graph uses the .meta, and saver.restore uses the .data and .index

(2)graph structure数据结构

当我们用默认的方式saver = tf.train.Saver()创建saver对象的时候,saver将持有graph里的所有的变量。那当我们分开save和restore的时候,就会出现集中情况:

restore时候的saver对象持有的variable是在save的时候的saver持有variable的一个子集:也就是训练时候的变量我们在测试的时候不一定都用,这时候我们就可以选取其子集创建使用,这种情况是没问题的;

restore时候的saver对象持有的variable在save的时候saver并没有持有。也就是说,我们在测试的时候定义了一个新的变量,这个变量在save的时候没有出现,那么这时候如果restore,因为保存的变量中没有这个新的变量,所以就会报错。例如,我们在上面的restore的python程序中,在v2变量下面加一个v3变量,v2 = tf.get_variable(“v2”, shape=[5]),运行一下,就会出现 NotFoundError 错误:

NotFoundError (see above for traceback): Key v3 not found in checkpoint [[Node: save/RestoreV2_2 = RestoreV2[dtypes=[DT_FLOAT],_device=”/job:localhost/replica:0/task:0/device:CPU:0”](_arg_save/Const_0_0, save/RestoreV2_2/tensor_names, save/RestoreV2_2/shape_and_slices)]]

(3)控制Saver的数据结构

根据上面的情况,我们只要确保restore时候的变量在save时都出现过就好了。但这样会给编程造成很大的不变。因为我们在测试的时候,很有可能会创建一些新的变量。针对这种情况,TensorFlow有两种方式可以解决:

- 创建Saver的时候,定义要保存的变量;这样我们在restore的时候,也一样定义要restore的变量,就好了;

v1 = tf.Variable(..., name='v1')
v2 = tf.Variable(..., name='v2')

# Pass the variables as a dict:
saver = tf.train.Saver({'v1': v1, 'v2': v2})

# Or pass them as a list.
saver = tf.train.Saver([v1, v2])
# Passing a list is equivalent to passing a dict with the variable op names
# as keys:
saver = tf.train.Saver({v.op.name: v for v in [v1, v2]})


但是有时候graph的结构比较复杂,要保存的变量很多,要一一对应还是很麻烦的。怎么办?采用tf.train.import_meta_graph()方法

#Create a saver.
saver = tf.train.Saver(...variables...)
#Remember the training_op we want to run by adding it to a collection.
tf.add_to_collection('train_op', train_op)
sess = tf.Session()
for step in xrange(1000000):
sess.run(train_op)
if step % 1000 == 0:
# Saves checkpoint, which by default also exports a meta_graph
# named 'my-model-global_step.meta'.
saver.save(sess, 'my-model', global_step=step)


在save的时候我们保存我们想要保存的变量,当然可以直接默认保存全部;在restore的时候,我们先导入保存的模型的数据结构,就可以了。

with tf.Session() as sess:
new_saver = tf.train.import_meta_graph('my-save-dir/my-model-10000.meta')
new_saver.restore(sess, 'my-save-dir/my-model-10000')
# tf.get_collection() returns a list. In this example we only want the
# first one.
train_op = tf.get_collection('train_op')[0]
for step in xrange(1000000):
sess.run(train_op)


Reference

- TensorFlow, why there are 3 files after saving the model?:

https://stackoverflow.com/questions/41265035/tensorflow-why-there-are-3-files-after-saving-the-model
内容来自用户分享和网络整理,不保证内容的准确性,如有侵权内容,可联系管理员处理 点击这里给我发消息
标签:  TensoFlow Save restore