您的位置:首页 > 理论基础 > 计算机网络

循环神经网络与自然语言处理

2018-09-28 20:23 260 查看

循环神经网络

人工神经网络和卷积神经网络的一个共同特点是,输出只依赖于输入。这在某些情况下不符合事实。在循环神经网络(RNN-Recurrent Neural Network)的应用场景中,输出不仅依赖于输入,而且依赖于“记忆”。打个比方说,我们人类的学习,不仅依赖于新知识的输入,而且依赖于我们已有的知识。

循环神经网络的网络结构

循环神经网络的网络结构如下:

其中每个圆圈可以看作是一个单元,而且每个单元做的事情也是一样的,因此可以折叠呈左半图的样子。用一句话解释RNN,就是一个单元结构重复使用。

在图中,xt表示时刻t的输入,st表示时刻t的状态,ot表示时刻t的输出。

其中:

St=f(W∗St−1+U∗xt) S_t = f(W*S_{t-1} + U*x_t)St​=f(W∗St−1​+U∗xt​)
也就是说,时刻t的状态是由当前时刻的输入和上一个时刻的状态(加上一定的权重)决定的。其中的f是激活函数。

ot=softmax(V∗st) o_t = softmax(V*s_t)ot​=softmax(V∗st​)

以下是RNN网络结构中的一些细节:

  • 可以把隐状态St看作“记忆体”,它捕捉了之前时间点上的信息。
  • 输出Ot由当前时间以及之前所有的“记忆”共同计算得到。
  • 实际应用中,St并不能捕捉和保留之前的所有信息(记忆有限)。
  • 不同于CNN,在RNN中这个神经网络都共享一组参数(U,V,W),这极大地减少了需要训练的参数量。
  • 图中的Ot在有些任务下是不需要的,比如文本情感分析,其实只需要最后的output结果就行。

循环神经网络的训练-BPTT

如前面我们讲的,如果要预测t时刻的输出,我们必须先利用上一时刻(t-1)的记忆和当前时刻的输入,得到t时刻的记忆:
st=tanh(Uxt+Wst−1) s_t = tanh(Ux_t + Ws_{t-1})st​=tanh(Uxt​+Wst−1​)
然后利用当前时刻的记忆,通过softmax分类器输出每个词出现的概率:
y^t=softmax(Vst) \hat{y}_t = softmax(Vs_t)y^​t​=softmax(Vst​)
为了找出模型最好的参数U,W,V,我们就要知道当前参数得到的结果怎么样,因此就要定义我们的损失函数,用交叉熵损失函数:
t时刻的损失:Et(yt,y^t)=−ytlogy^t t时刻的损失:E_t(y_t,\hat{y}_t) = -y_t log\hat{y}_tt时刻的损失:Et​(yt​,y^​t​)=−yt​logy^​t​
其中yt是t时刻的标准答案,是一个只有一个是1,其他都是0的向量;\hat{y}_t是我们预测出来的结果,与yt的维度一样,但它是一个概率向量,里面是每个词出现的概率。因为对结果的影响,肯定不止一个时刻,因此需要把所有时刻的造成的损失都加起来:
Et(yt,y^t)=−∑tytlogy^t E_t(y_t,\hat{y}_t) = -\sum\limits_t y_tlog\hat{y}_tEt​(yt​,y^​t​)=−t∑​yt​logy^​t​

如图所示,你会发现每个cell都会有一个损失,我们已经定义好了损失函数,接下来就是熟悉的一步了,那就是根据损失函数利用SGD来求解最优参数,在CNN中使用反向传播BP算法来求解最优参数,但在RNN就要用到BPTT,它和BP算法的本质区别,也是CNN和RNN的本质区别:CNN没有记忆功能,它的输出仅依赖与输入,但RNN有记忆功能,它的输出不仅依赖与当前输入,还依赖与当前的记忆。这个记忆是序列到序列的,也就是当前时刻收到上一时刻的影响,比如股市的变化。

因此,在对参数求偏导的时候,对当前时刻求偏导,一定会涉及前一时刻,我们用例子看一下:

假设我们对E3的W求偏导:它的损失首先来源于预测的输出\hat{y}_3,预测的输出又是来源于当前时刻的记忆s3,当前的记忆又是来源于当前的输出和截止到上一时刻的记忆:
s3=tanh(Ux3+Ws2) s_3 = tanh(Ux_3 + Ws_{2}) s3​=tanh(Ux3​+Ws2​)
因此,根据链式法则可以有:
∂E3∂W=∂E3∂y^3∂y^3∂s3∂s3∂W \frac{\partial E_3}{\partial W} = \frac{\partial E_3}{\partial \hat{y}_3} \frac{\partial \hat{y}_3}{\partial s_3} \frac{\partial s_3}{\partial W} ∂W∂E3​​=∂y^​3​∂E3​​∂s3​∂y^​3​​∂W∂s3​​
但是,你会发现,
s2=tanh(Ux2+Ws1) s_2 = tanh(Ux_2 + Ws_{1})s2​=tanh(Ux2​+Ws1​)
也就是s2s里面的函数还包含了W,因此,这个链式法则还没到底,就像图上画的那样,所以真正的链式法则是这样的:


我们要把当前时刻造成的损失,和以往每个时刻造成的损失加起来,因为我们每一个时刻都用到了权重参数W。和以往的网络不同,一般的网络,比如人工神经网络,参数是不同享的,但在循环神经网络,和CNN一样,设立了参数共享机制,来降低模型的计算量。

RNN可以写诗歌,写小说,这里有一个TensorFlow写的例子:
https://github.com/hzy46/Char-RNN-TensorFlow

阅读更多
内容来自用户分享和网络整理,不保证内容的准确性,如有侵权内容,可联系管理员处理 点击这里给我发消息
标签: