您的位置:首页 > 其它

从NN到RNN再到LSTM(附模型描述及详细推导)——(三)LSTM

2015-07-02 11:48 113 查看
作者是NLP的初学者,经由导师指导,稍微学习了解了一下LSTM的网络模型及求导,打算在这里分享一下,欢迎各种交流。

转载请注明出处:/article/3711224.html

以下内容主要引自Alex Graves写的Supervised Sequence Labelling with Recurrent Neural Networks 一书,并加上了个人的理解进行阐述。

http://www.springer.com/cn/book/9783642247965

前面讲的RNN,虽然能保存历史信息,但是RNN存在梯度衰减问题(vanishing gradient problem)




The problem is that the influence of a given input on the hidden layer, and therefore on the network output, either decays or blows up exponentially as it cycles around the network’s recurrent connections.

随着RNN隐层深度的增加,也就是随着时间的增加,梯度计算成指数倍增长或衰减。而且随着时间的增加,最早的历史信息可能被遗忘了(因为隐层激活值逐渐被新的输入覆盖)。这就有了长短时记忆(Long Short-Term Memory, Hochreiter and Schmidhuber, 1997)。LSTM也是一个循环神经网络,它是在RNN的基础上,为隐层每个节点开了三扇门:

1)Input Gate:表示是否允许信息加入到当前隐层节点中,如果为1(门开),则允许输入,如果为0(门关),则不允许,这样就可以摒弃掉一些没用的输入信息;

2)Forget Gate:表示是否保留当前隐层节点存储的历史信息,如果为1(门开),则保留,如果为0(门关),则清空当前节点所存储的历史信息;

3)Output Gate:表示是否将当前节点输出值输出给下一层(下一个隐层或者输出层),如果为1(门开),则当前节点的输出值将作用于下一层,如果为0(门关),则否。

三扇门用下面三个符号表示:ιϕω -- Input Gate -- Forget Gate -- Output Gate}∈(0,1)
\begin{array}{l}
\iota & \text{ -- Input Gate} \\
\phi & \text{ -- Forget Gate} \\
\omega & \text{ -- Output Gate}
\end{array} \} \in \left(0, 1\right) 很诡异,但是书上就这么写,我们也就这么看,可以不用记住,后面涉及到公式内容的时候回来查询就好了。

首先来回顾一下RNN的隐层计算公式:




隐层节点激活前后的值也是分开写,这里就有几点我的见解,也算是赘述:

1)ath: the network input to unit h at time ta_h^t \text{: the network input to unit }h\text{ at time }t .

2)bth: the activation of unit h at time tb_h^t \text{: the activation of unit }h\text{ at time }t .

3)bth: the output of unit h at time t\color{red}{b_h^t \text{: the output of unit }h\text{ at time }t}.

4)bt−1h′,∀h′: the history of unit h at time t\color{red}{b_{h’}^{t-1}, \forall h’ \text{: the history of unit }h\text{ at time }t}.

公式(1)就是atha_h^t的计算式,这个值将“输入”给隐层节点进行激活,这个激活(2)后的值就是bthb_h^t,也可以认为是这个隐层节点的“输出值”,就是要输出给下一个隐层或者输出层。在tt时间,历史信息就保存于t−1t-1时间的隐层中,也就是bt−1h′,∀h′b_{h’}^{t-1}, \forall h’,也是前一个隐层的“输出”,而bth,∀hb_h^t, \forall h也就是下一个隐层的历史信息。

LSTM就是将RNN的每个隐层节点变换成如下图右边部分这么一大块东西。







对应RNN隐层公式(1)和(2),LSTM这个大节点的计算公式如(3)(4)(5)所示。RNN的隐层节点的激活值就是输出值,也就是历史信息,而在LSTM中,一个隐层的“大节点”(叫做block)有一个中心cell,这个cell的state存储了历史信息,而整个“大节点”的“输出”(unit output)是这个state经过Output Gate变化后的值:式子(3)这样写只是为了对应式子(1),同样,cell的输入值包括输入向量xtx^t和前一个隐层的unit output。式子(4)将计算这个cell在tt时间的state(下标cc表示cell),包括两个部分:一个是经过Input Gate(ι)\left( \iota \right)变换的激活值,二是经过Forget Gate(ϕ)\left( \phi \right)变换的t−1t-1时间的cell的state(历史信息),这两部分加在一起将作为t+1t+1时间的历史信息。式子(5)则对cell state进行激活,并经过Output Gate(ω)\left( \omega \right)变换,得到的结果作为unit output。

在书中有一句话“Each block contains one or more self-connected memory cells and three multiplicative units - the input, output and forget gates – that provide continuous analogues of write, read and reset operations for the cells.”也就是上面那个一个大节点,里面可能也会有多个cell,那么记每个block中有CC个cell,并用下标cc指示。下面给出其它符号和具体公式。

–ι\iota, ϕ\phi, ω\omega分别表示Input Gate, Forget Gate, Output Gate。

–wijw_{ij}:从节点ii到节点jj的权重。那么从cell到三扇门的peephole weights分别为wcιw_{c\iota},wcϕw_{c\phi}和wcωw_{c\omega}。

–atja_j^t:节点jj在tt时间的输入值。

–btjb_j^t:节点jj在tt时间的激活值。

–stcs_c^t:cell cc在tt时间的state。

–ff:三扇门的激活函数,通常是sigmoid函数,这样能使激活值变换到(0,1)区间,刚好0表示门关,1表示门开。

–gg和hh分别表示cell的输入和输出的激活函数,常为tanh或者sigmoid。

–II,HH和KK则分别表示输入层的节点数,和隐层的Cell个数,输出层的节点数。

–并且只有cell output btcb_c^t会连接到隐层的其他“大节点”,而cell state,cell inputs,gate activations仅在“大节点”内可见。

前馈过程,这里我配合(Hochreiter and Schmidhuber, 1997)原文里的前馈过程,书中有如下一幅图:




图中,输入层有4个节点,输出层有5个节点,隐层有2个block,每个block有1个cell(C=1C=1,如果有多个cell,则重复公式即可),引用此图,可以根据图中的线,区分每个变量,以及这些变量从哪里来,要到哪里去。

InputInput Gate:Gate: 输入有三个来源:输入层节点,前一个隐层的cell outputs,前一个时间的cell states。输出就是一个0~1的值,表示门开或关(最下方的实心黑点)。 atι=∑i=1Iwiιxti+∑h=1Hwhιbt−1h+∑c=1Cwcιst−1c(6)a_\iota^t = \displaystyle\sum_{i=1}^Iw_{i\iota}x_i^t + \displaystyle\sum_{h=1}^Hw_{h\iota}b_h^{t-1} + \displaystyle\sum_{c=1}^Cw_{c\iota}s_c^{t-1} \tag{6} btι=f(atι)(7)b_\iota^t = f\left(a_\iota^t\right) \tag{7} ForgetForget Gate:Gate: 输入也有三个来源,同InputInput GateGate相同,输出也是一个0~1的值(中间的实心黑点)。 atϕ=∑i=1Iwiϕxti+∑h=1Hwhϕbt−1h+∑c=1Cwcϕst−1c(8)a_\phi^t = \displaystyle\sum_{i=1}^Iw_{i\phi}x_i^t + \displaystyle\sum_{h=1}^Hw_{h\phi}b_h^{t-1} + \displaystyle\sum_{c=1}^Cw_{c\phi}s_c^{t-1} \tag{8} btϕ=f(atϕ)(9)b_\phi^t = f\left(a_\phi^t\right) \tag{9} Cells:Cells: 这里计算cell states。最首先输入有两个来源(block最下方):输入层节点和前一个隐层的cell outputs。中心cell将根据InputInput GateGate判断是否将这个输入加入到state中,同时根据ForgetForget GateGate判断是否保留过去的state(见公式11)。 atc=∑i=1Iwicxti+∑h=1Hwhcbt−1h(10)a_c^t = \displaystyle\sum_{i=1}^Iw_{ic}x_i^t + \displaystyle\sum_{h=1}^Hw_{hc}b_h^{t-1} \tag{10} stc=btϕst−1c+btιg(atc)(11)\color{red}{ s_c^t = b_\phi^ts_c^{t-1} + b_\iota^tg\left(a_c^t\right) \tag{11}} OutputOutput Gate:Gate: 输入来源有三个:输入层节点,前一个隐层的cell outputs,当前时间的cell states。输出就是一个0~1的值(最上方的实心黑点)。 atω=∑i=1Iwiωxti+∑h=1Hwhωbt−1h+∑c=1Cwcωstc(12)a_\omega^t = \displaystyle\sum_{i=1}^Iw_{i\omega}x_i^t + \displaystyle\sum_{h=1}^Hw_{h\omega}b_h^{t-1} + \displaystyle\sum_{c=1}^Cw_{c\omega}s_c^t \tag{12} btω=f(atω)(13)b_\omega^t = f\left(a_\omega^t\right) \tag{13} CellCell Outputs:Outputs: btc=btωh(stc)(14)b_c^t = b_\omega^th\left(s_c^t\right) \tag{14} (下面是个人见解)这里的CC和HH推敲了很久,HH表示的是隐层的cell个数(书中说的,就是cell的总数了,而非隐层的block个数),而一个block又有CC个cell,从前馈过程的各个式子中可以看出一个block中的所有cell都共享三个Gates,有几个cell,整个block就有几个输出,所以作者直接用HH表示全部的cell数更简便,应该是这样。而且需要注意,cell states仅在block内可见,因此式子(6)(8)(12)中的scs_c都是同一个block在不同时间的cell states,而非隐层所有的cell states。

最后从隐层到输出层的公式可以为(这条式子书中没有,是我根据后面梯度的式子反过来推的):atk=∑c=1Hwckbtc(15)a_k^t=\displaystyle\sum_{c=1}^Hw_{ck}b_c^t \tag{15}

反向传播:

定义:δtj=def∂∂atj\delta_j^t \quad {\overset{\text{def}}=} \quad {\cfrac{\partial{\mathcal{L}}}{\partial{a_j^t}}} ϵtc=def∂∂btcϵts=def∂∂stc\epsilon_c^t \quad {\overset{\text{def}}=} \quad {\cfrac{\partial{\mathcal{L}}}{\partial{b_c^t}}} \qquad \epsilon_s^t \quad {\overset{\text{def}}=} \quad {\cfrac{\partial{\mathcal{L}}}{\partial{s_c^t}}} CellCell Outputs:Outputs:

ϵtc=∑k=1K∂∂atk∂atk∂btc+∑h=1H∂∂at+1h∂at+1h∂btc=∑k=1Kwckδtk+∑h=1Hwchδt+1h(16)\begin{align}
\epsilon_c^t & = \displaystyle\sum_{k=1}^K {\cfrac{\partial{\mathcal{L}}}{\partial{a_k^t}}} {\cfrac{\partial{a_k^t}}{\partial{b_c^t}}} + \color{red}{\displaystyle\sum_{h=1}^H{\cfrac{\partial{\mathcal{L}}}{\partial{a_h^{t+1}}}} {\cfrac{\partial{a_h^{t+1}}}{\partial{b_c^t}}}} \tag{16} \\
& = \displaystyle\sum_{k=1}^Kw_{ck}\delta_k^t \color{red} {+ \displaystyle\sum_{h=1}^Hw_{ch}\delta_h^{t+1}}
\end{align} OutputOutput Gates:Gates:

δtω=∂btω∂atω∑c=1C∂∂btc∂btc∂btω=f′(atω)∑c=1Ch(stc)ϵtc(17)\begin{align}
\delta_\omega^t & = {\cfrac{\partial{b_\omega^t}}{\partial{a_\omega^t}}} \displaystyle\sum_{c=1}^C {\cfrac{\partial{\mathcal{L}}}{\partial{b_c^t}}} {\cfrac{\partial{b_c^t}}{\partial{b_\omega^t}}} \tag{17} \\
& = f’\left(a_\omega^t\right)\displaystyle\sum_{c=1}^Ch\left(s_c^t\right)\epsilon_c^t
\end{align} States:States: scs_c 这个变量在式子(6)(8)(11)(12)(14)的右边都有出现,因此要依次反向传播回来。

ϵts=∂∂btc∂btc∂stc(14)+∂∂atω∂atω∂stc(12)+∂∂st+1c∂st+1c∂stc(11)+∂∂at+1ϕ∂at+1ϕ∂stc(8)+∂∂at+1ι∂at+1ι∂stc(6)=ϵtcbtωh′(stc)+wcωδtω+bt+1ϕϵt+1s+wcϕδt+1ϕ+wcιδt+1ι(18)\begin{align}
\epsilon_s^t & = {\cfrac{\partial{\mathcal{L}}}{\partial{b_c^t}}} {\cfrac{\partial{b_c^t}}{\partial{s_c^t}}}\color{red}{\left(14\right)} + {\cfrac{\partial{\mathcal{L}}}{\partial{a_\omega^{t}}}} {\cfrac{\partial{a_\omega^{t}}}{\partial{s_c^t}}}\color{red}{\left(12\right)} + {\cfrac{\partial{\mathcal{L}}}{\partial{s_c^{t+1}}}} {\cfrac{\partial{s_c^{t+1}}}{\partial{s_c^t}}}\color{red}{\left(11\right)} \\
& + {\cfrac{\partial{\mathcal{L}}}{\partial{a_\phi^{t+1}}}} {\cfrac{\partial{a_\phi^{t+1}}}{\partial{s_c^t}}}\color{red}{\left(8\right)} + {\cfrac{\partial{\mathcal{L}}}{\partial{a_\iota^{t+1}}}} {\cfrac{\partial{a_\iota^{t+1}}}{\partial{s_c^t}}}\color{red}{\left(6\right)} \\
& = \epsilon_c^tb_\omega^th’\left(s_c^t\right) + w_{c\omega}\delta_\omega^t + b_\phi^{t+1}\epsilon_s^{t+1} + w_{c\phi}\delta_\phi^{t+1} + w_{c\iota}\delta_\iota^{t+1} \tag{18}
\end{align} Cells:Cells:

δtc=∂∂stc∂stc∂atc=ϵtsbtιg′(atc)(19)\begin{align}
\delta_c^t & = {\cfrac{\partial{\mathcal{L}}}{\partial{s_c^t}}}
{\cfrac{\partial{s_c^t}}{\partial{a_c^t}}} \tag{19} \\
& = \epsilon_s^tb_\iota^tg’\left(a_c^t\right)
\end{align} ForgetForget Gates:Gates:

δtϕ=∂btϕ∂atϕ∑c=1C∂∂stc∂stc∂btϕ=f′(atϕ)∑c=1Cϵtsst−1c(20)\begin{align}
\delta_\phi^t & = {\cfrac{\partial{b_\phi^t}}{\partial{a_\phi^t}}} \displaystyle\sum_{c=1}^C {\cfrac{\partial{\mathcal{L}}}{\partial{s_c^t}}} {\cfrac{\partial{s_c^t}}{\partial{b_\phi^t}}} \tag{20} \\
& = f’\left(a_\phi^t\right) \displaystyle\sum_{c=1}^C\epsilon_s^ts_c^{t-1}
\end{align} InputInput Gates:Gates:

δtι=∂btι∂atι∑c=1C∂∂stc∂stc∂btι=f′(atι)∑c=1Cϵtsg(atc)(21)\begin{align}
\delta_\iota^t & = {\cfrac{\partial{b_\iota^t}}{\partial{a_\iota^t}}} \displaystyle\sum_{c=1}^C {\cfrac{\partial{\mathcal{L}}}{\partial{s_c^t}}} {\cfrac{\partial{s_c^t}}{\partial{b_\iota^t}}} \tag{21} \\
& = f’\left(a_\iota^t\right) \displaystyle\sum_{c=1}^C\epsilon_s^tg\left(a_c^t\right)
\end{align}

以上,就是LSTM的前馈和反向传播的过程。略复杂。

微软的俞栋领导了一个项目,实现了一个开源工具CNTK(Computational Network Toolkit),这个工具可以非常简单地实现网络构建和测试过程,包括DNN,CNN,RNN,LSTM,最大熵模型等。截至2015年7月5日,当前版本是CNTK(Windows+Linux) 2015-04-15。

CNTK: http://cntk.codeplex.com/
内容来自用户分享和网络整理,不保证内容的准确性,如有侵权内容,可联系管理员处理 点击这里给我发消息
标签: