您的位置:首页 > 大数据 > 人工智能

tf.train.XXX与train有关的函数

2018-03-19 14:59 225 查看

tf.train.XXX与train有关的函数

tf.train.get_or_create_global_step()

这个函数主要用于返回或者创建(如果有必要的话)一个全局步数的tensor。参数只有一个,就是图,如果没有指定那么就是默认的图。

tf.trainable_variables()

返回所有
trainable=True
的变量。

当我们在声明变量
Variable()
时传入
trainable=True
Variable()
构造函数会自动添加新的变量到图中的集合
GraphKeys.TRAINABLE_VARIABLES
,这个函数实质上就是返回这个集合中的变量。

tensorflow.python.training.moving_averages.assign_moving_average

这个函数的参数如下:

def assign_moving_average(variable, value, decay, zero_debias=True, name=None):


对于
variable
的滑动平均更新为:variable=variable∗decay+value∗(1−decay)variable=variable∗decay+value∗(1−decay)

下面是一个简单的例子(可以看出variable是变量,而value是常量),这个函数主要应用于batch_normalization

def testAssignMovingAverage(self):
with self.test_session():
var = tf.Variable([10.0, 11.0])
val = tf.constant([1.0, 2.0], tf.float32)
decay = 0.25
assign = moving_averages.assign_moving_average(var, val, decay)
tf.global_variables_initializer().run()
self.assertAllClose([10.0, 11.0], var.eval())
assign.op.run()
self.assertAllClose([10.0 * 0.25 + 1.0 * (1.0 - 0.25),
11.0 * 0.25 + 2.0 * (1.0 - 0.25)],
var.eval())
内容来自用户分享和网络整理,不保证内容的准确性,如有侵权内容,可联系管理员处理 点击这里给我发消息
标签: