您的位置:首页 > 理论基础 > 计算机网络

RNN学习笔记(二)-Gradient Analysis

2016-01-07 15:47 218 查看
、# RNN学习笔记(二)-Gradient Analysis

RNN网络具有对时间序列建模的特性,其在时间轴上可以展开成一个多层前馈网络,因此也存在多层网络同样的问题,随着网络递归层数的增加,误差梯度的传递将出现不稳定的情况(消散或膨胀),下边将进行深入分析:

1.BPTT算法回顾及符号定义

2.误差传导分析

3.全局的误差传导分析

4.参考文章

1.BPTT算法回顾及符号定义

δk(t)=∂J(t)∂sk(t)=f′k(sk(t))(dk(t)−yk(t))

yk(t)=fk(sk(t))

sk(t)=∑j∈Uwijyj(t−1)

δj(t)=f′j(sj(t))∑i∈Uwijδi(t+1)

2.误差传导分析

下面,对δ(t)进行深入分析。

∂δv(t−q)∂δu(t)=⎧⎩⎨⎪⎪⎪⎪⎪⎪⎪⎪f′v(sv(t−1))wuv if q=1f′v(sv(t−q))∑l∈U∂δl(t−q+1)∂δu(t)wlv if q>1

针对q>1的情形进行展开:

f′v(sv(t−q))∑l∈U∂δl(t−q+1)∂δu(t)wlv=f′v(sv(t−q))∑l∈U∂[f′l(sl(t−q+1))∑l′∈Uwl′lδl′(t−q+2)]∂δu(t)wlv

为了符号表示上的便利,我们引入新的符号:

lm:第m次迭代的求和下标变量,且对所有的lm满足lm∈U。同时,为了书写简洁,我们在求和表达式中省略这个限制条件。

m:从时刻t开始,向后迭代的次数。

根据定义,有:l0=u,lq=v,

于是,上式可以改写为:

f′lq(slq(t−q))∑lq−1∂δlq−1(t−q+1)∂δl0(t)wlq−1lq=f′lq(slq(t−q))∑lq−1∂[f′lq−1(slq−1(t−q+1))∑lq−2wlq−2lq−1δlq−2(t−q+2)]∂δl0(t)wlq−1lq

紧接上式继续推导:

f′lq(slq(t−q))∑lq−1f′lq−1(slq−1(t−q+1))wlq−1lq∑lq−2wlq−2lq−1∂δlq−2(t−q+2)∂δl0(t)

进一步化简得:

f′lq(slq(t−q))∑lq−1∑lq−2f′lq−1(slq−1(t−q+1))wlq−1lqwlq−2lq−1∂δlq−2(t−q+2)∂δl0(t)

把其中的δlq−2(t−q+2)用δlq−3(t−q+3)继续展开化简得:

f′lq(slq(t−q))∑lq−1∑lq−2∑lq−3f′lq−1(slq−1(t−q+1))f′lq−2(slq−2(t−q+2))wlq−1lqwlq−2lq−1wlq−3lq−2∂δlq−3(t−q+3)∂δl0(t)

继续往下展开并化简,直到l1:

f′lq(slq(t−q))∑lq−1⋯∑l1f′lq−1(slq−1(t−q+1))⋯f′l2(sl2(t−2))wlq−1lq⋯wl1l2∂δl1(t−1)∂δl0(t)

按本节最开始给出的式子,当q=1时,对上式最后边的偏导运算进行化简:

f′lq(slq(t−q))∑lq−1⋯∑l1f′lq−1(slq−1(t−q+1))⋯f′l2(sl2(t−2))wlq−1lq⋯wl1l2×f′l1(sl1(t−1))wl0l1

使用变量m代替连乘式l的下标,则上式可以改写为:

f′lq(slq(t−q))∑lq−1⋯∑l1∏m=1q−1f′lm(slm(t−m))wlm−1lm

把上式左侧的f′lq(slq(t−q))写入连乘式,并调整求和号顺序:

Dvu(t,q)=∂δv(t−q)∂δu(t)=∂δlq(t−q)∂δl0(t)=∑l1⋯∑lq−1∏m=1qf′lm(slm(t−m))wlm−1lm

可以看出,后边的乘法项部分是影响偏导数值的关键。下边,我们重点考察这一部分:Δm=|f′lm(slm(t−m))wlm−1lm|。假设隐层节点数为n,则求和运算中一共有nq−1个形如∏m=1qf′lm(slm(t−m))wlm−1lm的项。

1.当Δm>1.0时,显然Dvu(t,q)将呈指数增加;

2.当Δm<1.0时,显然Dvu(t,q)将呈指数减小(gradient vanishes);

3.当Δm=1.0时,显然Dvu(t,q)将呈线性变化;

假设flm是sigmoid函数,易知max(f′lm)=0.251.

也就是说要想让Δm≥1.0,必然要有|wlmlm+1|≥4。

设ylm−1为不等于0的常数,当Δm取得最大值时,

f′lm(slm(t−m))=0.25

slm(t−m)=∑lm+1wlmlm+1ylm+1(t−m+1)=0

wlm−1lm=1ylm−1coth(12slm)2

首先要说明一点,RNN网络按时间展开后,其隐藏层的权值是共享的,即对于不同时刻的权值,有wij(t1)=wij(t2)(这一部分理解还不够深刻,有错误之处欢迎指正)。

当slm(t−m)=0时,显然有|coth(12slm)|→∞

Δm=|f′lm(slm(t−m))wlm−1lm|=|f′lm(slm(t−m))||wlm−1lm|

这里把wlm−1lm看成变量,为了简化符号,做如下定义:

w:=wlm−1lm

x:=slm(t−m)

g(x):=f′lm(slm(t−m))

于是,

Δm=|g(x)||w|=|1ex+e−x+2||w|=|wex+e−x+2|

假设上一时刻的输出ylm+1为不等于0的常数,由x=slm(t−m)=∑lm+1ylm+1wlmlm+1,由于隐藏层的权值共享,可以看出,x其实是w的线性函数。所以当w以线性方式趋于无穷时,Δm的分母将以指数增长的方式趋于无穷大,因此有:

w→∞,Δm→0,这个结论也说明了,不能以单纯增大w的初始值的方式来避免梯度消散的问题。因为一味的增大w反而会使后向传递的误差变得更小。

只要|wlm−1lm|<4,必然就有Δm<1,就存在梯度消散的问题。

3.全局的误差传导分析

接下来定义一些新的符号:

n:隐层的节点数;

W:权值矩阵,[W]ij:=wij;

Wv:输出权值向量,[Wv]i:=[W]iv=wiv;

WuT:输入权值向 量,[WuT]i:=[W]ui=wui;

gi(m):f‘i(si(t−m)),第t−m时刻隐层第i个节点激活函数的导数值;

F′(t−m):对角矩阵,即[F′(t−m)]ij:=0,if i≠j;[F′(t−m)]ij:=gi(m),if i=j;

[A]ij:矩阵A第i列第j行的元素;

[x]i:向量x的第i个元素;

∥⋅∥A:矩阵A的范数;

∥x∥x:向量x的范数;

f′max:=maxm=1,...,q{∥F′(t−m)∥A};

ek:[e]k=1,其余元素为0的向量。

从第2节中的结果开始:

Dvu(t,q)=∑l1⋯∑lq−1∏m=1qf′lm(slm(t−m))wlm−1lm

=∑lq−1⋯∑l1f′l1(sl1(t−1))wl0l1f′l2(sl2(t−2))wl1l2⋯f′lq−2(slq−2(t−q+2))wlq−3lq−2f′lq−1(slq−1(t−q+1))wlq−2lq−1×f′lq(slq(t−q))wlq−1lq

为了方便讨论,设q=4,n=2,代入得:

∑l3∑l2∑l1f′l1(sl1(t−1))wl0l1f′l2(sl2(t−2))wl1l2f′l3(sl3(t−3))wl2l3f′l4(sl4(t−4))wl3l4

=∑l3=1n∑l2=1n∑l1=1nf′l1(sl1(t−1))wul1f′l2(sl2(t−2))wl1l2f′l3(sl3(t−3))wl2l3f′v(sv(t−4))

=∑l3=1n∑l2=1n∑l1=1ngl1(1)wul1gl2(2)wl1l2gl3(3)wl2l3gv(4)wl3v

=[wu1 wu2][g1(1) 00 g2(1)][g1(2) 00 g2(2)][w11 w21w12 w22][g1(3) 00 g2(3)][w11 w21w12 w22][w1vw2v]gv(4)

=(WuT)F′(t−1)∏m=23(F′(t−m)W)Wvf′v(sv(t−q))

所以,原式可以展开为:

Dvu(t,q)=(WuT)⎡⎣⎢⎢g1(1)⋱gn(1)⎤⎦⎥⎥⎡⎣⎢⎢g1(2)⋱gn(2)⎤⎦⎥⎥⎡⎣⎢⎢w11⋮w1n⋯⋯wn1⋮wnn⎤⎦⎥⎥×⋯×⎡⎣⎢⎢g1(q−1)⋱gn(q−1)⎤⎦⎥⎥⎡⎣⎢⎢w11⋮w1n⋯⋯wn1⋮wnn⎤⎦⎥⎥Wvf′v(sv(t−q))

进一步化简得:

(WuT)F′(t−1)∏m=2q−1(F′(t−m)W)Wvf′v(sv(t−q))3

设maxm=1,...,n{|xi|}≤∥x∥x,

则必有|xTy|≤n∥x∥x∥y∥x

因此,f′v(sv(t−q))≤∥F′(t−q)∥A≤f′max

|Dvu(t,q)|≤n(f′max)q∥Wv∥x∥WTu∥x(∥W∥A)q−2q

因为:

∥Wv∥x=∥Wev∥A≤∥W∥A∥ev∥x≤∥W∥A

∥WTu∥x=∥euW∥A≤∥eu∥x∥W∥A≤∥W∥A

所以,有:

|Dvu(t,q)|≤n(f′max∥W∥A)q

范数有多种计算方法,这里可以采用如下方式计算:

∥W∥A:=maxr∑s|wrs|(取和最大的行)

∥x∥x:=maxr|xr|(取绝对值最大的元素)

当f′max=0.25, 如果下式成立:

wij≤wmax≤4.0n,∀i,j

则有∥W∥A≤nwmax≤4.0,令τ:=(nwmax4.0)<1.0

有:

|Dvu(t,q)|≤n(τ)q,随着q的增大,该式将指数衰减。

4.参考文献

1.LONG SHORT-TERM MEMORY,Neural Computation 9(8):1735-1780, 1997.Sepp Hochreiter,Jurgen Schmidhuber

f(x)=11+e−x

f′(x)=f(x)(1−f(x))=11+e−xe−x1+e−x

化简得:

f′(x)=e−x1+e−2x+2e−x

分式上下同乘以ex,得:

g(x)=f′(x)=1ex+e−x+2

要求g(x)的最大值,先对g(x)求导:

g′(x)=−ex−e−x(ex+e−x+2)2

分母恒大于0,只需要考察分子K=−(ex−e−x)

显然,当x>0时,K恒小于0,g(x)单减;

当x<0时,K恒大于0,g(x)单增;

当x=0时,K=0,此时必然为g(x)最大值

max(g(x))=g(0)=1e0+e−0+2=12+2=14
coth(x)=1tanh=ex+e−xex−e−x,双曲函数



这里F′(t−m)W跟paper上的顺序刚好相反,原因是wlm−1lm的下标与paper上相反(详情参考参文献中的1.)
内容来自用户分享和网络整理,不保证内容的准确性,如有侵权内容,可联系管理员处理 点击这里给我发消息