您的位置:首页 > 理论基础 > 计算机网络

深层神经网络2

2017-11-03 18:15 120 查看

使用验证集判断模型效果

为了评测神经网络模型在不同参数下的效果,一般会从训练集中抽取一部分作为验证数据。除了使用验证数据集,还可以采用

交叉验证(cross validation )
的方式验证模型效果,但是使用交叉验证会花费大量的时间。但在海量数据情况下,一般采用验证数据集的形式评测模型的效果。
一般采用的验证数据分布越接近测试数据分布,模型在验证数据上的表现越可以体现模型在测试数据上的保险。
使用滑动平均模型和指数衰减的学习率在一定程度上都是限制神经网络中参数更新的速度。
在处理复杂问题时,使用滑动平均模型、指数衰减的学习率和正则化损失可以明显提升模型的训练效果。

变量管理

Tensorflow提供了通过变量名称来创建或者获取一个变量的机制,避免了复杂神经网络频繁传递参数的情况。通过该机制,在不同的函数中可以直接通过变量的名字来使用变量,而不需要将变量通过参数的形式到处传递。
Tensorflow中通过变量名获取变量的机制主要通过

tf.get_variable()
tf.variable_scope()
函数实现。

  1. tf.get_variable()

    该函数创建变量的方法和
    tf.Variable()
    函数的用法基本一样,提供维度信息(
    shape
    )以及初始化方法(
    initializer
    )的参数。该函数的变量名称是一个必填参数,函数会根据这个名字去创建或者获取变量。当已经有同名参数时,会报错。
  2. tf.variable_scope()

    该函数可以控制
    tf.get_variable()
    函数的语义。当
    tf.variable_scope()
    函数使用参数
    reuse=True
    生成上下文管理器时,这个上下文管理器内所有的
    tf.get_variable()
    函数会直接获取已经创建的变量。如果不存在,则报错;当
    reuse=False
    或者
    reuse=None
    创建上下文管理器时,
    tf.get_variable()
    操作将创建新的变量,如果同名变量已经存在,则报错。
    同时
    tf.variable_scope()
    函数可以嵌套。新建一个嵌套的上下文管理器但不指定reuse,这时的reuse的取值和外面一层保持一致。当退出reuse设置为True的上下文之后reuse的值又回到了False(内层reuse不设置)。

同时,tf.variable_scope()函数生成的上下文管理器也会创建一个Tensorflow中的命名空间,在命名空间内创建的变量名称都会带上这个命名空间名作为前缀。可以直接通过带命名空间名称的变量名来获取其它命名空间下的变量(创建一个名称为空的命名空间,并设置为reuse=True)。

with tf.variable_scope(" ", reuse=True):
v5 = tf.get_variable("foo/bar/v", [1])
print(v5.name)
===>v:0   # 0表示variable这个运算输出的第一个结果

Tensorflow模型持久化

将训练得到的模型保存下来,可以方便下次直接使用(避免重新训练花费大量的时间)。Tensorflow提供的持久化机制可以将训练之后的模型保存到文件中。
Tensorflow提供了

tf.train.Saver
类来保存和还原神经网络模型。当保存模型之后,目录下一般会出现三个文件,这是因为Tensorflow会将计算图的结构和图上参数值分开保存。

  1. model.ckpy.meta
    文件,保存了Tensorflow计算图的结构。
  2. model.ckpt
    文件,保存了Tensorflow程序每一个变量的取值。
  3. checkpoint
    文件,保存了一个目录下所有的模型文件列表。

保存模型

saver = tf.train.Saver()

saver.save(sess, "path/model.ckpt")

加载模型,此时不用进行变量的初始化过程
saver.restore(sess, "path/model.ckpt")

sess.run(result)

为了保存和加载部分变量,在声明tf.train.Saver类时可以提供一个列表来指定需要保存或加载的变量,

saver = tf.train.Saver([v1])
。同时,tf.train.Saver类也支持在保存或者加载时给变量重命名,如果直接加载就会导致程序报变量找不到的错误,Tensorflow提供通过字典将模型保存时的变量名和要加载的变量联系起来。

v = tf.Variable(tf.constant(1.0, shape=[1]), name='other-v1')
saver = tf.train.Saver({"v1": v})
将原先变量名为v1的变量加载到变量v中,变量v的名称为other-v1。

这样做的目的时为了方便使用变量的滑动平均值。因为每一个变量的滑动平均值是通过影子变量维护的,如果在加载模型时直接将影子变量映射到变量自身,就不需要在调用函数来获取变量的平均值了。

为了方便加载重命名滑动平均变量,tf.train.ExponentialMovingAverage类提供了variables_to_restore()函数来生成tf.train.Saver类所需要的变量重命名字典。

v = tf.Variable(0)
ema = tf.train.ExponentialMovingAverage(0.99)
saver = tf.train.Saver(ema.variable_to_restore())
with tf.Session() as sess:
saver.restore(sess, "path/model.ckpt")
sess.run(v)

有时候不需要类似于变量初始化、模型保存等辅助节点的信息,Tensorflow提供了convert_variables_to_constants()函数将计算图中的变量及其取值通过常量的方式保存。

持久化原理及数据格式

Tensorflow程序中所有计算都会被表达为计算图上的节点。

MetaGraphDef

Tensorflow通过

元图(MetaGraph)
来记录计算图中节点的信息以及运行计算图中节点所需要的元数据,元图是由
MetaGraphDef Protocol Buffer
定义的,
MetaGraphDef
中的内容构成了Tensorflow持久化的第一个文件,也就是
model.ckpt.meta
文件。

  • meta_info_def
    属性,记录了Tensorflow计算图中的元数据以及Tensorflow程序中所有使用到的运算方法的信息。元数据包括了计算图的版本号以及用户指定的一些标签,其中
    meta_info_def
    属性的
    stripped_op_list
    属性保存了Tensorflow运算方法的信息,如果一个运算方法在计算图中出现了多次,在该字段中也只出现一次。
    stripped_op_list
    属性的类型是
    OpList
    OpList
    类型是一个
    OpDef
    类型的列表,该类型定义了一个运算的所有信息,包括运算名、输入输出和运算的参数信息。
  • graph_def
    属性,主要记录了Tensorflow计算图上的节点信息,Tensorflow计算图的每一个节点对应了Tensorflow程序中的一个运算。
    meta_info_def
    属性已经包含了所有运算的具体信息,所以
    graph_def
    属性只关注运算的连接结果。
    该属性是通过GraphDef Protocol Buffer定义的,GraphDef主要包含了一个
    NodeDef
    类型的列表,
    GraphDef
    versions
    属性存储了Tensorflow的版本号,
    node
    属性记录了所有的节点信息。
    node
    NodeDef
    类型,该类型的
    op
    属性给出了该节点使用的运算方法名称,具体信息可以通过
    meta_info_def
    获取,
    input
    属性是一个字符串列表,定义了运算的输入,
    device
    属性定义了处理该运算的设备,
    attr
    属性定义了和当前运算相关的配置信息。
  • saver_def
    属性,记录了持久化模型所需要用到的一些参数,比如保存到文件的文件名,保存操作和加载操作的名称以及保存频率、清理历史记录等。
    该属性主要通过
    SaverDef
    定义。
  • collention_def
    属性,Tensorflow计算图中可以维护不同的集合,底层实现就是通过
    collention_def
    这个属性。
    collection_def
    属性是一个从集合名称到集合内容的映射,其中集合名称为字符串,集合内容为
    CollentionDef Protocol Buffer
    。Tensorflow计算图上的集合主要可以维护4类不同的集合:
    NodeList
    用于维护计算图上的节点集合;
    BytesList
    用于维护字符串或者序列化之后的Protocol Buffer的集合;
    Int64List
    用于维护整数集合;
    FloatList
    用于维护实数集合。
SSTable

持久化Tensorflow中变量的取值,

tf.Save
r得到的
model.ckpt
文件保存了所有的变量,该文件使用
SSTable
格式存储的,相当于一个
(key, value)
列表。

CheckpointState

持久化的最后一个文件名叫

checkpoint
,这个文件是
tf.train.Saver
类自动生成且自动维护的。该文件中维护了一个由
tf.train.Saver
类持久化的所有Tensoflow模型文件的文件名,当某个模型文件被删除时,这个模型对应的文件名也会被移除,checkpoint中内容的格式为
CheckpointState Protocol Buffer

阅读更多
内容来自用户分享和网络整理,不保证内容的准确性,如有侵权内容,可联系管理员处理 点击这里给我发消息
标签: