您的位置:首页 > 其它

指数滑动平均(ExponentialMovingAverage)EMA

2017-12-10 21:23 323 查看
EMA被广泛的应用在深度学习的BN层中,RMSprop,adadelta,adam等梯度下降方法

tf.train.ExponentialMovingAverage

函数定义

tensorflow中提供了tf.train.ExponentialMovingAverage来实现滑动平均模型,他使用指数衰减来计算变量的移动平均值。

tf.train.ExponentialMovingAverage.init(self, decay, num_updates=None, zero_debias=False, name="ExponentialMovingAverage"):

decay是衰减率在创建ExponentialMovingAverage对象时,需指定衰减率(decay),用于控制模型的更新速度。decay设置为接近1的值比较合理,通常为:0.999,0.9999。这里的一个trick是,



例如,

0.95^(20)=0.3584

1/e=0.3678

两者大概是近似相等的,也许这就是指数滑动平均中指数的含义吧。

影子变量的初始值与训练变量的初始值相同。当运行变量更新时,每个影子变量都会更新为:



num_updates是ExponentialMovingAverage提供用来动态设置decay的参数,当初始化时提供了参数,即不为none时,每次的衰减率是:



apply()方法添加了训练变量的影子副本,并保持了其影子副本中训练变量的移动平均值操作。在每次训练之后调用此操作,更新移动平均值。

average()和average_name()方法可以获取影子变量及其名称。

Tensorflow栗子:

import tensorflow as tf

# 定义一个32位浮点数的变量,初始值位0.0
v1 =tf.Variable(dtype=tf.float32, initial_value=0.)

# 衰减率decay,初始值位0.99
decay = 0.99

# 定义num_updates,同样,初始值位0
num_updates = tf.Variable(0, trainable=False)

# 定义滑动平均模型的类,将衰减率decay和num_updates传入。
ema = tf.train.ExponentialMovingAverage(decay=decay, num_updates=num_updates)

# 定义更新变量列表
update_var_list = [v1]

# 使用滑动平均模型
ema_apply = ema.apply(update_var_list)

# Tensorflow会话
with tf.Session() as sess:
# 初始化全局变量
sess.run(tf.global_variables_initializer())

# 输出初始值
print(sess.run([v1, ema.average(v1)]))
# [0.0, 0.0](此时 num_updates = 0 ⇒ decay = .1, ),
# shadow_variable = variable = 0.

# 将v1赋值为5
sess.run(tf.assign(v1, 5))

# 调用函数,使用滑动平均模型
sess.run(ema_apply)

# 再次输出
print(sess.run([v1, ema.average(v1)]))
# 此时,num_updates = 0 ⇒ decay =0.1, v1 = 5;
# shadow_variable = 0.1 * 0 + 0.9 * 5 = 4.5 ⇒ variable

# 将num_updates赋值为10000
sess.run(tf.assign(num_updates, 10000))

# 将v1赋值为10
sess.run(tf.assign(v1, 10))

# 调用函数,使用滑动平均模型
sess.run(ema_apply)

# 输出
print(sess.run([v1, ema.average(v1)]))
# decay = 0.99,shadow_variable = 0.99 * 4.5 + .01*10 ⇒ 4.555

# 再次使用滑动平均模型
sess.run(ema_apply)

# 输出
print(sess.run([v1, ema.average(v1)]))
# decay = 0.99,shadow_variable = .99*4.555 + .01*10 = 4.609
for i in range(1000):
sess.run(ema_apply)
print(sess.run([v1,ema.average(v1)]))
内容来自用户分享和网络整理,不保证内容的准确性,如有侵权内容,可联系管理员处理 点击这里给我发消息
标签: