您的位置:首页 > 其它

LSTM模型简介及Tensorflow实现

2017-11-14 17:21 351 查看
LSTM模型在RNN模型的基础上新增加了单元状态C(cell state)。

一. 模型的输入和输出

在t时刻,LSTM的输入有3个:

(1) 当前时刻LSTM的输入值x(t);

(2) 上一时刻LSTM的输出值h(t-1);

(3) 上一时刻的单元状态c(t-1);

LSTM的输出有2个:

(1) 当前时刻LSTM的输出值h(t);

(2) 当前时刻的单元状态c(t);

二. 模型的计算



(1) 遗忘门:forget gate,控制上一时刻的单元状态有多少传入:



(2) 输入门:input gate,控制上一时刻LSTM的输出有多少传入:



(3) 当前时刻输入的单元状态:



(4) 当前时刻LSTM的单元状态:



(5) 输出门:output gate,控制有多少传入到LSTM当前时刻的输出:



(6) 当前时刻LSTM的输出:



note:公式中的X表示对应元素相乘;

三. TensorFlow实现LSTM-regression模型

# load module
from tensorflow.example.tutorial.mmist import input_data
import tensorflow as tf
import numpy as np

# definite hyperparameters
BATCH_SIZE = 64
TIME_STEP = 28
INPUT_SIZE = 28
LR = 0.01

# load data
mnist = input_data.read_data_sets('mnist', one_hot=True)

# test data
t
4000
est_x = mnist.test.images[:2000]
test_y = mnist.test.labels[:2000]

# placeholder
tf_x = tf.placeholder(tf.float32, [None, TIME_STEP * INPUT_SIZE])
image = tf.reshape(tf_x, [-1, TIME_STEP, INPUT_SIZE])
tf_y = tf.placeholder(tf.int32, [None, 10])

# RNN
rnn_cell = tf.contrib.rnn.BasicLSTMCell(num_units=64)
outputs, (h_c, h_n) = tf.nn.dynamic_rnn(rnn_cell, image, dtype=tf.float32)
loss = tf.losses.softmax_cross_entropy(onehot_labels=tf_y, logits=output)
train_op = tf.train.AdamOptimizer(LR).minimize(loss)
accuracy = tf.metrics.accuracy(labels=tf.argmax(tf_y, axis=1), predictions=tf.argmax(output, axis=1),)[1]

# open an tf session
sess = tf.Session()
init_op = tf.group(tf.global_variables_initializer(), tf.local_variables_initializer())
sess.run(init_op)

# train
for step in range(1200):
b_x, b_y = mnist.train.next_batch(BATCH_SIZE)
_, loss_ = sess.run([train_op, loss], {tf_x: b_x, tf_y: b_y})
if step % 50 == 0:
accuracy_ = sess.run(accuracy, {tf_x: test_x, tf_y: test_y})
print('train loss: %.4f' % loss_, '| test accuracy: %.2f' % accuracy_)

test_output = sess.run(output, {tf_x: test_x[: 10]})
pred_y = np.argmax(test_output, 1)
print(pred_y, 'prediction_number')
print(np.argmax(test_y[: 10], 1), 'real number')


四. 参考

(1) 韩炳涛系列文章:https://www.zybuluo.com/hanbingtao/note/581764

(2) 莫烦系列教程: https://github.com/MorvanZhou/Tensorflow-Tutorial/tree/master/tutorial-contents
内容来自用户分享和网络整理,不保证内容的准确性,如有侵权内容,可联系管理员处理 点击这里给我发消息
标签: