您的位置:首页 > 理论基础 > 计算机网络

Task03 循环神经网络进阶(pytorch代码实现)

2020-03-05 19:09 537 查看

循环神经网络进阶

⻔控循环神经⽹络(GRU)

当时间步数较⼤或者时间步较小时, 循环神经⽹络的梯度较容易出现衰减或爆炸。虽然裁剪梯度可以应对梯度爆炸,但⽆法解决梯度衰减的问题。通常由于这个原因,循环神经⽹络在实际中较难捕捉时间序列中时间步距离较⼤的依赖关系。

⻔控循环神经⽹络(GRU):捕捉时间序列中时间步距离较⼤的依赖关系

CNN:

GRU:

• 重置⻔有助于捕捉时间序列⾥短期的依赖关系;
• 更新⻔有助于捕捉时间序列⾥⻓期的依赖关系。

GRU pytorch简洁代码实现

import numpy as np
import torch
from torch import nn, optim
import torch.nn.functional as F
import sys
sys.path.append(".")
import d2lzh_pytorch as d2l
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

(corpus_indices, char_to_idx, idx_to_char, vocab_size) = d2l.load_data_jay_lyrics()

num_hiddens=256
num_epochs, num_steps, batch_size, lr, clipping_theta = 160, 35, 32, 1e2, 1e-2
pred_period, pred_len, prefixes = 40, 50, ['分开', '不分开']

lr = 1e-2 # 注意调整学习率
gru_layer = nn.GRU(input_size=vocab_size, hidden_size=num_hiddens)
model = d2l.RNNModel(gru_layer, vocab_size).to(device)
d2l.train_and_predict_rnn_pytorch(model, num_hiddens, vocab_size, device,
corpus_indices, idx_to_char, char_to_idx,
num_epochs, num_steps, lr, clipping_theta,
batch_size, pred_period, pred_len, prefixes)

长短期记忆(LSTM)

LSTM是比GRU更加复杂一点的门控循环单元,它引入了3个门和一个记忆细胞的概念;
•输入门:控制当前时间步的输入;
•遗忘门:控制上一时间步的记忆细胞 ;
•输出门:控制从记忆细胞到隐藏状态;
•记忆细胞:⼀种特殊的隐藏状态的信息的流动


其中公式为:
It=σ(XtWxi+Ht−1Whi+bi)
Ft=σ(XtWxf+Ht−1Whf+bf)
Ot=σ(XtWxo+Ht−1Who+bo)
C˜t=tanh(XtWxc+Ht−1Whc+bc)
Ct=Ft⊙Ct−1+It⊙C˜t
Ht=Ot⊙tanh(Ct)

LSTM的pytorch简洁实现

import numpy as np
import torch
from torch import nn, optim
import torch.nn.functional as F
import sys
sys.path.append(".")
import d2lzh_pytorch as d2l
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

(corpus_indices, char_to_idx, idx_to_char, vocab_size) = d2l.load_data_jay_lyrics()
num_hiddens=256
num_epochs, num_steps, batch_size, lr, clipping_theta = 160, 35, 32, 1e2, 1e-2
pred_period, pred_len, prefixes = 40, 50, ['分开', '不分开']

lr = 1e-2 # 注意调整学习率
lstm_layer = nn.LSTM(input_size=vocab_size, hidden_size=num_hiddens)
model = d2l.RNNModel(lstm_layer, vocab_size)
d2l.train_and_predict_rnn_pytorch(model, num_hiddens, vocab_size, device,
corpus_indices, idx_to_char, char_to_idx,
num_epochs, num_steps, lr, clipping_theta,
batch_size, pred_period, pred_len, prefixes)

深度循环神经网络

深度循环神经网络的pytorch代码

import numpy as np
import torch
from torch import nn, optim
import torch.nn.functional as F
import sys
sys.path.append(".")
import d2lzh_pytorch as d2l
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

(corpus_indices, char_to_idx, idx_to_char, vocab_size) = d2l.load_data_jay_lyrics()

num_hiddens=256
num_epochs, num_steps, batch_size, lr, clipping_theta = 160, 35, 32, 1e2, 1e-2
pred_period, pred_len, prefixes = 40, 50, ['分开', '不分开']

lr = 1e-2 # 注意调整学习率
#其中num_layers=2 为网络层数
gru_layer = nn.LSTM(input_size=vocab_size, hidden_size=num_hiddens,num_layers=2)
model = d2l.RNNModel(gru_layer, vocab_size).to(device)
d2l.train_and_predict_rnn_pytorch(model, num_hiddens, vocab_size, device,
corpus_indices, idx_to_char, char_to_idx,
num_epochs, num_steps, lr, clipping_theta,
batch_size, pred_period, pred_len, prefixes)

双向循环神经网络

num_hiddens=128
num_epochs, num_steps, batch_size, lr, clipping_theta = 160, 35, 32, 1e-2, 1e-2
pred_period, pred_len, prefixes = 40, 50, ['分开', '不分开']

lr = 1e-2 # 注意调整学习率
#通过参数bidirectional=True来进行控制
gru_layer = nn.GRU(input_size=vocab_size, hidden_size=num_hiddens,bidirectional=True)
model = d2l.RNNModel(gru_layer, vocab_size).to(device)
d2l.train_and_predict_rnn_pytorch(model, num_hiddens, vocab_size, device,
corpus_indices, idx_to_char, char_to_idx,
num_epochs, num_steps, lr, clipping_theta,
batch_size, pred_period, pred_len, prefixes)

小结

※ 门控循环神经网络 GRU
•GRU可以更好的捕捉时间序列中时间步距离较大的依赖关系
•GRU单元引入了门的概念,从而修改了循环神经网络中隐藏状态的计算方式。它包括重置门、更新门、候选隐藏状态和隐藏状态
•重置门有助于捕捉时间序列里短期的依赖关系
•更新门有助于捕捉时间序列里长期的依赖关系

※长短期记忆LSTM
•长短期记忆的隐藏层输出包括隐藏状态和记忆细胞。只有隐藏状态会传递到输出层
•长短期记忆的输入门、遗忘门和输出门可以控制信息流动
•长短期记忆可以应对循环神经网络中的梯度衰减问题,并更好的捕捉时间序列中时间步距离较大的依赖关系

深度循环神经网络
•在深度循环神经网络中,隐藏状态的信息不断传递至当前层的下一时间步和当前时间步的下一层

双向循环神经网络
双向循环神经网络在每个时间步的隐藏状态同事取决于改时间步和之前和之后的值序列(包括当前时间步的输入)

  • 点赞
  • 收藏
  • 分享
  • 文章举报
l_yiyu 发布了7 篇原创文章 · 获赞 0 · 访问量 444 私信 关注
内容来自用户分享和网络整理,不保证内容的准确性,如有侵权内容,可联系管理员处理 点击这里给我发消息
标签: 
相关文章推荐