您的位置:首页 > 理论基础 > 计算机网络

RNN学习笔记(一)-简介及BPTT RTRL及Hybrid(FP/BPTT)算法

2015-12-29 16:14 471 查看

RNN学习笔记(一)-简介及BPTT RTRL及Hybrid(FP/BPTT)算法

本文假设读者已经熟悉了常规的神经网络,并且了解了BP算法,如果还不了解的,参见UFIDL的教程。

- 1.RNN结构

- 2.符号定义

- 3.网络unrolled及公式推导

- 4.BPTT

- 5.RTRL

- 6.Hybrid(FP/BPTT)

- 7.参考文献

1.RNN结构

如下图1是一个最简单的RNN:



其中集合I为m个外部输入节点,左下角的U为前一时刻的隐层输出节点,U中的节点数为n,并假定U中所有节点的输出都参与到下一时刻的输入。

2.符号定义

定义:

xi(t):t时刻第i个输入节点的输出值,且i∈I∪U

sk(t):t时刻第k个隐层节点的输出值,且k∈U

yk(t):t时刻第k个输出层节点的输出值,且k∈U

dk(t):t时刻隐层第k个节点的期望输出(即训练数据)

wli:第i个输入到第l个隐层节点的权重,其中i∈I,l∈U

wlk:第k个输入到第l个隐层节点的权重,其中k,l∈U

τ:假定网络的起始时刻为t0,当前时刻为t,t′∈[t0,t),τ∈(t′,t]

y∗k(τ):τ时刻第k个输出节点的输出值,且k∈U,且τ∈(t0,t],对于所有的τ而言,其实有yk(τ)=y∗k(τ),这里之所以引入新符号,是为了避免求导运算时混淆1

再来是一组等式定义:

sk(τ+1)=wx(τ)

ek(t)=dk(t)−yk(t)

J(τ)=∑k∈Uek(t)

Jtotal(t′,t)=∑τ=t′+1tJ(τ),t′∈[t0,t)

ϵk(τ;F)=∂F∂yk(τ)

ek(τ;F)=∂F∂y∗k(τ)

δk(τ;F)=∂F∂sk(τ)

pkij(τ)=∂yk(τ)∂wij

因为假定F只与yk(τ),τ∈(t′,t]显式相关,所以,当τ≤t′时,ek(τ;F)=0。

由于F是任意与yk(t)相关的函数,实际应用中,可以取

F=J(τ);F=Jtotal(t′,t)或其它函数。

因为初始状态的输出yk(t0)为预设值,与w之间不存在函数关系,所以当τ=t0时,pkij(t0)=0。

3.网络unrolled及公式推导

将网络按时间展开:



根据上图,下面两个式子成立:

sk(t+1)=∑l∈Uwklyl(t)+∑l∈Iwklxnetl(t)=∑l∈U∪Iwklxl(t)......(2)

yk(t+1)=fk(sk(t+1))......(3)

显然,y∗k(τ+1),y∗k(τ+2),...,y∗k(t)可以表示成s(τ+1)的函数,因此,

F=F(y∗(t′),y∗(t′+1),...,yk(τ),s(τ+1))=F

下面对公式进行进一步的推导:

ϵk(τ;F)=∂F∂yk(τ)

=∂F(y∗(t′),y∗(t′+1),...,yk(τ),s(τ+1))∂yk(τ)

由复合函数求导法则,上式可进一步变为:

∂F∂y(t′)∂y(t′)∂yk(τ)+∂F∂y(t′+1)∂y(t′+1)∂yk(τ)+...+∂F∂y∗(τ)∂y∗(τ)∂yk(τ)+∂F∂s(τ+1)∂s(τ+1)∂yk(τ)

当τ′<τ时,显然y(τ′)与y(τ)无关,故上式的前半部分为0,即:

ϵk(τ;F)=∂F∂y∗(τ)∂y∗(τ)∂yk(τ)+∂F∂s(τ+1)∂s(τ+1)∂yk(τ)

这里:

∂F∂y∗(τ)=⎡⎣⎢⎢⎢⎢⎢⎢⎢⎢⎢⎢⎢⎢⎢∂F∂y∗1(τ)∂F∂y∗2(τ)...∂F∂y∗k(τ)...∂F∂y∗n(τ)⎤⎦⎥⎥⎥⎥⎥⎥⎥⎥⎥⎥⎥⎥⎥

∂y∗(τ)∂yk(τ)=⎡⎣⎢⎢⎢⎢⎢⎢⎢⎢⎢⎢⎢⎢⎢⎢∂y∗1(τ)∂y∗k(τ)∂y∗2(τ)∂y∗k(τ)...∂y∗k(τ)∂y∗k(τ)...∂y∗n(τ)∂y∗k(τ)⎤⎦⎥⎥⎥⎥⎥⎥⎥⎥⎥⎥⎥⎥⎥⎥=⎡⎣⎢⎢⎢⎢⎢⎢⎢⎢00...1...0⎤⎦⎥⎥⎥⎥⎥⎥⎥⎥

∂F∂s(τ+1)=⎡⎣⎢⎢⎢⎢⎢⎢⎢⎢⎢⎢⎢⎢∂F∂s1(τ+1)∂F∂s2(τ+1)...∂F∂sl(τ+1)...∂F∂sn(τ+1)⎤⎦⎥⎥⎥⎥⎥⎥⎥⎥⎥⎥⎥⎥=⎡⎣⎢⎢⎢⎢⎢⎢⎢⎢⎢δ1(τ+1;F)δ2(τ+1;F)...δl(τ+1;F)...δn(τ+1;F)⎤⎦⎥⎥⎥⎥⎥⎥⎥⎥⎥

∂s(τ+1)∂yk(τ)=⎡⎣⎢⎢⎢⎢⎢⎢⎢⎢⎢⎢⎢⎢⎢⎢∂s∗1(τ+1)∂y∗k(τ)∂s∗2(τ+1)∂y∗k(τ)...∂s∗l(τ+1)∂y∗k(τ)...∂s∗n(τ+1)∂y∗k(τ)⎤⎦⎥⎥⎥⎥⎥⎥⎥⎥⎥⎥⎥⎥⎥⎥=⎡⎣⎢⎢⎢⎢⎢⎢⎢⎢w1kw2k...wlk...wnk⎤⎦⎥⎥⎥⎥⎥⎥⎥⎥

代入,上式可以变为:

ϵk(τ;F)=⎡⎣⎢⎢⎢⎢⎢⎢⎢⎢⎢⎢⎢⎢⎢∂F∂y∗1(τ)∂F∂y∗2(τ)...∂F∂y∗k(τ)...∂F∂y∗n(τ)⎤⎦⎥⎥⎥⎥⎥⎥⎥⎥⎥⎥⎥⎥⎥T⎡⎣⎢⎢⎢⎢⎢⎢⎢⎢00...1...0⎤⎦⎥⎥⎥⎥⎥⎥⎥⎥+⎡⎣⎢⎢⎢⎢⎢⎢⎢⎢⎢δ1(τ+1;F)δ2(τ+1;F)...δl(τ+1;F)...δn(τ+1;F)⎤⎦⎥⎥⎥⎥⎥⎥⎥⎥⎥T⎡⎣⎢⎢⎢⎢⎢⎢⎢⎢w1kw2k...wlk...wnk⎤⎦⎥⎥⎥⎥⎥⎥⎥⎥=∂F∂y∗k(τ)+∑l∈Uwlkδl(τ+1;F)

所以就有:

ϵk(τ;F)=∂F∂y∗k(τ)+∑l∈Uwlkδl(τ+1;F)=ek(τ;F)+∑l∈Uwlkδl(τ+1;F)

因为当τ=t时,ϵk(t;F)=ek(t;F),所以有:

ϵk(τ;F)=⎧⎩⎨⎪⎪ek(t;F) if τ=tek(τ;F)+∑l∈Uwlkδl(τ+1;F) if τ<t

δk(τ;F)=∂F∂sk(τ)=∂F∂yk(τ)∂yk(τ)∂sk(τ)=ϵk(τ;F)f′k(sk(τ))

进一步推导:

ϵk(τ;F)=(ek(τ;F)+∑l∈Uwlkδl(τ+1;F))f′k(sk(τ))

先做如下定义:

wij:第j个输入到第i个隐层节点的权重(迭代更新之前),其中i∈U,j∈U∪I

wij(τ):τ时刻第j个输入到第i个隐层节点的权重(迭代更新之前),其中τ∈[t0,t),i∈U,j∈U∪I

∂F∂wij(τ)=∂F∂si(τ+1)∂si(τ+1)∂wij(τ)=δi(τ+1;F)xj(τ)

∂F∂wij=∑τ=t0t−1∂F∂wij(τ)∂wij(τ)∂wij=∑τ=t0t−1∂F∂wij(τ)=∑τ=t0t−1δi(τ+1;F)xj(τ)

4.BPTT(Back Propagation Through Time)

4.1 Real-Time BPTT

算法描述:

令τ∈(t0,t],k∈U,

ϵk(t)=ek(t),

δk(τ)=f′k(sk(τ))ϵk(τ),

ϵk(τ−1)=∑l∈Uwlkδl(τ),

可以看出,算法的公式与BP算法非常相似,算法从t时刻开始,先用等式ϵk(t)=ek(t)求出ϵk(t),然后再用后边两个等式继续向后迭代,直到t0。这里的第一步也被称为错误注入(injecting error),也说是在t时刻注入了ek(t)。



上图描述了Real-Time BPTT算法在每一个时刻t的存储和处理操作。历史缓存每经过一个时刻t,就会增加一层的数据(包括该t时刻所有的输入和输出值)。实线箭头表明了当前的输出值由和上一时刻的输入输出值确定。虚线表示反向传播,计算直到t0+1的δ。步骤①为injecting error操作,剩下的步骤为每一步的误差计算。

激活函数通常取logistics函数,此时的f′k(sk(τ))=fk(sk(τ))(1−fk(sk(τ)))

最后,权值的梯度通过下式计算:

∂J(t)∂wij=∑τ=t0+1tδi(τ)xj(τ−1)

在每一个时刻t,算法的执行流程如下:

(1)将当前网络的状态和当前的输入值添加到历史缓存2

(2)注入当前时刻t的ek(t),然后在时间区间(t0,t]上进行反向传播,计算出所有的ϵk(τ),δk(τ);

(3)计算所有的∂J(t)∂wij;

(4)根据第(3)步的结果修改权值。

随着时间的增长,算法对历史缓存的需求将是无限的,因此,有时也用BPTT(∞)来表示这个算法,它在理论上的研究价值要远大于实用。接下来,我们将讨论更为实用的近似算法。

4.2 Epochwise BPTT

为了解决Real-Time BPTT对内存的无限制需求,我们采用一种近似的算法,即:Epochwise BPTT。

算法的目标是计算基于Jtotal(t0,t1)的梯度(即损失函数F=Jtotal(t0,t1)),其步骤跟前边类似。同样的,

令τ∈(t0,t1],k∈U,

ϵk(t1)=ek(t1),

δk(τ)=f′k(sk(τ))ϵk(τ),

ϵk(τ−1)=ek(τ−1)+∑l∈Uwlkδl(τ),

算法从最后的时刻t1开始,injecting error ek(t1),然后运用后边两个等式,迭代计算δk(τ),ϵk(τ−1),直到τ=t0+1。此时权值的梯度按下式计算:

∂Jtotal(t0,t1)∂wij=∑τ=t0+1t1δi(τ)xj(τ−1)



对[t0,t1]中所有的输入输出以及目标值都被存储在历史缓存中。实线表示输出由上一时刻的输入和输出确定,当一次epoch完成后,执行BP操作(虚线箭头)。奇数索引的步骤表示error injection,偶数索引的步骤表示误差(δ)传播。一旦BP操作完成,每个权值的梯度就可以算出来了。

算法的执行流程如下:

(1)执行BP算法,计算所有的ϵk(τ),δk(τ),τ∈(t0,t1];

(2)计算所有的∂Jtotal(t0,t1)∂wij;

(3)使用(2)的结果更新权值,重复步骤(1)~(3);

5.RTRL(Real-Time Recurrent Learning)

与反向传播的BPTT算法不同的是,RTRL通过前向传播梯度来进行计算。

对任意的k∈U,i∈U,j∈U∪I,以及t∈[t0,t1],定义:

pkij(t)=∂yk(t)∂wij

令F=J(t),有:

∂J(t)∂wij=∑k∈Uek(t)pkij(t)

根据之前的关系等式:

sk(t+1)=∑l∈Uwklyl(t)+∑l∈Iwklxnetl(t)=∑l∈U∪Iwklxl(t)......(2)

yk(t+1)=fk(sk(t+1))......(3)

可以推出:

pkij(t+1)=∂yk(t+1)∂wij=∂yk(t+1)∂sk(t+1)∂sk(t+1)∂wij=f′k(sk(t+1))[∑l∈Uwklplij(t)+δikxj(t)]3

此外,t0时刻的输出为预设值,与连接权值无关,所以有:

pkij(t0)=∂yk(t0)∂wij=0

于是,整个计算过程将从t=t0开始迭代计算,直到t=t1。

对每一个时刻t,计算相应的yk(t)以及∂J(t)∂wij

6.Hybrid(FP/BPTT)

∂F∂wij=∑τ=t0t′−1∂F∂wij(τ)+∑τ=t′t−1∂F∂wij(τ)

等式右边的第一部分可写为:

∑τ=t0t′−1∂F∂wij(τ)=∑τ=t0t′−1∑l∈U∂F∂yl(t′)∂yl(t′)∂wij(τ)=∑l∈U∂F∂yl(t′)∑τ=t0t′−1∂yl(t′)∂wij(τ)=∑l∈U∂F∂yl(t′)∂yl(t′)∂wij=∑l∈Uϵl(t′;F)plij(t′)

因此,最初的式子可变为:

∂F∂wij=∑l∈Uϵl(t′;F)plij(t′)+∑τ=t′t−1δi(τ+1;F)xj(τ)

令F=Jtotal(t′,t)

∂Jtotal(t′,t)∂wij=∑l∈Uϵl(t′)plij(t′)+∑τ=t′t−1δi(τ+1)xj(τ)

首先计算BPTT:

ϵk(τ)=⎧⎩⎨⎪⎪δkr if τ=t∑l∈Uwlkδl(τ+1) if τ<t

然后,使用上边的计算结果执行:

prij(t)=∑l∈Uϵl(t′)plij(t′)+∑τ=t′t−1δl(τ+1)xj(τ)



上图是FP/BPTT(h)算法的简单描述。可以看到,算法包含两个连续的误差计算过程。一个在时刻t,另一个在时刻t+h.从时刻t−h直到时刻t的输入、输出和目标值都存储在历史缓存中。

7.参考文献

1.Gradient-Based Learning Algorithms for Recurrent Networks and Their Computational Complexity.Ronald J. Williams,David Zipser

F:F为{yk(τ)|k∈U,τ∈(t′,t]}的函数,

即F=F(yk(t′+1),yk(t′+2),...,yk(τ),...,yk(t))

这地方稍微深入说明一下引入变量y∗k(τ)的原因:

假设有函数f(x,y)=x+2y,同时,y,x满足:y=x2

对f(x,y)求偏导数:∂f∂x=∂(x+2y)∂x

这个地方出现了两个x(分别在分式的上下边),这两个x虽然相等,但含义其实并不相同。下边的x是自变量,上边的x其实可以看做自变量的一个函数,不妨令t=x,于是有如下关系式:

{x(t)=ty(t)=t2

于是f(x,y)=f(x(t),y(t))

∂f∂x=∂f(x(t),y(t))∂t

由复合函数求导法则,上式又可变为:

∂f(x(t),y(t))∂x(t)∂x(t)∂t+∂f(x(t),y(t))∂y(t)∂y(t)∂t

由于x(t),y(t)是t的单变量函数,有:

∂x(t)∂t=dx(t)dt

∂y(t)∂t=dy(t)dt

所以有:

∂f∂x=∂f(x(t),y(t))∂x(t)dx(t)dt+∂f(x(t),y(t))∂y(t)dy(t)dt

类比函数即F=F(y(t′+1),y(t′+2),...,yk(τ),...,y(t)),对其求关于yk(τ)的偏导数显然也存在符号混淆的问题,所以,有必要引入符号

y∗k(τ)=y∗k(τ)(yk(τ))=yk(τ)

y∗k(τ)(yk(τ))后边的括号表示y∗k(τ)为yk(τ)的函数。变量符号y∗k(τ)的意义与上例中x(t)的意义一样。
历史缓存(History buffer)中存储了整个网络从t0时刻开始的输入和激活信息。
δik是克罗内克函数(Kronecker delta)

函数定义:

δik={1 if i=k0 if i≠k
内容来自用户分享和网络整理,不保证内容的准确性,如有侵权内容,可联系管理员处理 点击这里给我发消息