Implementing ST-LSTM in Lua Torch
2017-12-29 12:20
Reference article: https://www.csdn.net/article/2015-09-14/2825693
Original post: https://apaszke.github.io/lstm-explained.html
In a vanilla LSTM the previous state fed into the cell consists only of c and h; ST-LSTM splits the previous state into two parts, a temporal one (t) and a spatial one (j).
The elements of the input table are routed to prev_cj, prev_hj, prev_ct, and prev_ht.
1. Defining the inputs
```lua
-- there will be 4*n+1 inputs
local inputs = {}                       -- create an empty table
table.insert(inputs, nn.Identity()())   -- x; nn.Identity() passes its input through (placeholder for the input data)
for L = 1, n do
  table.insert(inputs, nn.Identity()()) -- prev_cj[L]
  table.insert(inputs, nn.Identity()()) -- prev_hj[L]
end
for L = 1, n do
  table.insert(inputs, nn.Identity()()) -- prev_ct[L]
  table.insert(inputs, nn.Identity()()) -- prev_ht[L]
end
local x, input_size_L
```
```lua
local outputs = {}
for L = 1, n do
  -- c, h from previous steps
  local prev_cj = inputs[L*2]
  local prev_hj = inputs[L*2+1]
  local prev_ct = inputs[n*2+L*2]
  local prev_ht = inputs[n*2+L*2+1]
  -- the input to this layer
  if L == 1 then
    x = inputs[1]
    input_size_L = input_size
  else
    x = outputs[(L-1)*2]
    if dropout > 0 then x = nn.Dropout(dropout)(x) end -- apply dropout, if any
    input_size_L = rnn_size
  end
```
2. Linear transformation of the input
rnn_size is the number of hidden units. After the linear transformation, the result is split into 5 chunks (a vanilla LSTM uses four).

```lua
-- evaluate the input sums at once for efficiency
local i2h  = nn.Linear(input_size_L, 5 * rnn_size)(x):annotate{name = 'i2h_' .. L}
local h2hj = nn.Linear(rnn_size, 5 * rnn_size)(prev_hj):annotate{name = 'h2hj_' .. L}
local h2ht = nn.Linear(rnn_size, 5 * rnn_size)(prev_ht):annotate{name = 'h2ht_' .. L}
local all_input_sums = nn.CAddTable()({i2h, h2hj, h2ht})
local reshaped = nn.Reshape(5, rnn_size)(all_input_sums)
local n1, n2, n3, n4, n5 = nn.SplitTable(2)(reshaped):split(5)
```
3. Nonlinear transformation of the input
```lua
-- decode the gates
local in_gate       = nn.Sigmoid()(n1)
local forget_gate_j = nn.Sigmoid()(n2)
local forget_gate_t = nn.Sigmoid()(n3)
local out_gate      = nn.Sigmoid()(n4)
-- decode the write inputs
local in_transform  = nn.Tanh()(n5)
```
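Written out as equations (the notation below is my own shorthand, not from the original post): with $n_1,\dots,n_5$ denoting the five rnn_size-sized chunks produced by the split above, the gates and candidate input are

```latex
i = \sigma(n_1), \qquad
f_j = \sigma(n_2), \qquad
f_t = \sigma(n_3), \qquad
o = \sigma(n_4), \qquad
g = \tanh(n_5)
```

where $\sigma$ is the logistic sigmoid, and $f_j$, $f_t$ are the spatial and temporal forget gates that gate prev_cj and prev_ct in the state update below.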
4. State update
```lua
local next_c = nn.CAddTable()({
  nn.CMulTable()({forget_gate_j, prev_cj}),
  nn.CMulTable()({forget_gate_t, prev_ct}),
  nn.CMulTable()({in_gate, in_transform})
})
-- gated cells form the output
local next_h = nn.CMulTable()({out_gate, nn.Tanh()(next_c)})

table.insert(outputs, next_c)
table.insert(outputs, next_h)
```
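The fragments above never close the `for L = 1, n do` loop, nor wrap the graph into a module. A minimal end-to-end sketch, following the structure of the lstm-explained post, is given below; the function name `build_st_lstm` and the argument list are my own assumptions, not from the original code.

```lua
-- Sketch only: assembles the fragments above into one nn.gModule.
-- Requires the 'nn' and 'nngraph' packages.
require 'nn'
require 'nngraph'

local function build_st_lstm(input_size, rnn_size, n, dropout)
  dropout = dropout or 0
  local inputs = {}
  table.insert(inputs, nn.Identity()())            -- x
  for L = 1, n do
    table.insert(inputs, nn.Identity()())          -- prev_cj[L]
    table.insert(inputs, nn.Identity()())          -- prev_hj[L]
  end
  for L = 1, n do
    table.insert(inputs, nn.Identity()())          -- prev_ct[L]
    table.insert(inputs, nn.Identity()())          -- prev_ht[L]
  end

  local x, input_size_L
  local outputs = {}
  for L = 1, n do
    local prev_cj = inputs[L*2]
    local prev_hj = inputs[L*2+1]
    local prev_ct = inputs[n*2+L*2]
    local prev_ht = inputs[n*2+L*2+1]
    if L == 1 then
      x, input_size_L = inputs[1], input_size
    else
      x = outputs[(L-1)*2]
      if dropout > 0 then x = nn.Dropout(dropout)(x) end
      input_size_L = rnn_size
    end
    -- linear transforms, summed and split into 5 chunks
    local i2h  = nn.Linear(input_size_L, 5 * rnn_size)(x)
    local h2hj = nn.Linear(rnn_size, 5 * rnn_size)(prev_hj)
    local h2ht = nn.Linear(rnn_size, 5 * rnn_size)(prev_ht)
    local sums = nn.CAddTable()({i2h, h2hj, h2ht})
    local n1, n2, n3, n4, n5 =
      nn.SplitTable(2)(nn.Reshape(5, rnn_size)(sums)):split(5)
    -- gates and candidate input
    local in_gate       = nn.Sigmoid()(n1)
    local forget_gate_j = nn.Sigmoid()(n2)
    local forget_gate_t = nn.Sigmoid()(n3)
    local out_gate      = nn.Sigmoid()(n4)
    local in_transform  = nn.Tanh()(n5)
    -- state update: two forget paths (spatial j and temporal t)
    local next_c = nn.CAddTable()({
      nn.CMulTable()({forget_gate_j, prev_cj}),
      nn.CMulTable()({forget_gate_t, prev_ct}),
      nn.CMulTable()({in_gate, in_transform})
    })
    local next_h = nn.CMulTable()({out_gate, nn.Tanh()(next_c)})
    table.insert(outputs, next_c)
    table.insert(outputs, next_h)
  end  -- this 'end' is what the fragments above leave implicit

  return nn.gModule(inputs, outputs)
end
```

A module built this way takes a table of 4*n+1 tensors (x, then the cj/hj pairs, then the ct/ht pairs) and returns a {next_c, next_h} pair per layer; the surrounding training loop is responsible for feeding those back along both the j and t directions.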