
LSTM Text Classification (TensorFlow)

2018-01-16 10:54
1) Introduction to LSTM

Reposted from https://www.csdn.net/article/2015-09-14/2825693

Gates:

i_t = sigmoid(W_xi x_t + W_hi h_{t-1} + b_i)    (input gate)
f_t = sigmoid(W_xf x_t + W_hf h_{t-1} + b_f)    (forget gate)
o_t = sigmoid(W_xo x_t + W_ho h_{t-1} + b_o)    (output gate)

Input transform:

c_in_t = tanh(W_xc x_t + W_hc h_{t-1} + b_c_in)

State update:

c_t = f_t ⊙ c_{t-1} + i_t ⊙ c_in_t
h_t = o_t ⊙ tanh(c_t)

Here ⊙ is element-wise multiplication, x_t is the current input, and h_{t-1}, c_{t-1} are the previous hidden and cell states.
Described as a picture, this looks something like the figure below:

[Figure: LSTM cell diagram]

Input

First, let's define the input format. In Lua, array-like objects are called tables; this network will accept a table of tensors like the following:

local inputs = {}
table.insert(inputs, nn.Identity()())  -- network input
table.insert(inputs, nn.Identity()())  -- c at time t-1
table.insert(inputs, nn.Identity()())  -- h at time t-1
local input = inputs[1]
local prev_c = inputs[2]
local prev_h = inputs[3]


Computing the gate values

local i2h = nn.Linear(input_size, 4 * rnn_size)(input)    -- input to hidden
local h2h = nn.Linear(rnn_size, 4 * rnn_size)(prev_h)     -- hidden to hidden
local preactivations = nn.CAddTable()({i2h, h2h})         -- i2h + h2h




Applying the nonlinearities

-- gates
local pre_sigmoid_chunk = nn.Narrow(2, 1, 3 * rnn_size)(preactivations)
local all_gates = nn.Sigmoid()(pre_sigmoid_chunk)

-- input
local in_chunk = nn.Narrow(2, 3 * rnn_size + 1, rnn_size)(preactivations)
local in_transform = nn.Tanh()(in_chunk)


After the nonlinearities, we just need a few more nn.Narrow modules to split out the individual gates, and the gates are done.

local in_gate = nn.Narrow(2, 1, rnn_size)(all_gates)
local forget_gate = nn.Narrow(2, rnn_size + 1, rnn_size)(all_gates)
local out_gate = nn.Narrow(2, 2 * rnn_size + 1, rnn_size)(all_gates)




Computing the current cell state

-- previous cell state contribution
local c_forget = nn.CMulTable()({forget_gate, prev_c})
-- input contribution
local c_input = nn.CMulTable()({in_gate, in_transform})
-- next cell state
local next_c = nn.CAddTable()({
  c_forget,
  c_input
})


Computing the hidden state

local c_transform = nn.Tanh()(next_c)
local next_h = nn.CMulTable()({out_gate, c_transform})




Complete example: http://apaszke.github.io/assets/posts/lstm-explained/multilayer.lua

2) Text classification with an LSTM

Reposted from http://blog.csdn.net/u010223750/article/details/53334313?locationNum=7&fps=1

2.1 Principle

[Figure: model architecture — embedding layer → LSTM → mean pooling → softmax]

A brief walk through the figure: each word first passes through an embedding layer and then enters the LSTM layer (a standard LSTM). Running the LSTM over the sequence yields t hidden-state vectors, one per time step. A mean-pooling layer averages these vectors into a single vector h, which is followed by a simple logistic regression layer (or a softmax layer) that outputs a distribution over the classes. A small sketch of the pooling step follows below.
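
To make the mean-pooling step concrete, here is a minimal NumPy sketch. The shapes mirror the TensorFlow code in section 2.3; all names and sizes here are illustrative only:

import numpy as np

num_step, batch, hidden = 4, 2, 3
outputs = np.random.randn(num_step, batch, hidden)    # one LSTM output per time step
mask = np.ones((num_step, batch), dtype=np.float32)   # 1.0 = real token, 0.0 = padding
mask[3, 1] = 0.0                                      # the second example has only 3 tokens

# zero out the padded positions, then average over the valid time steps
h = (outputs * mask[:, :, None]).sum(axis=0) / mask.sum(axis=0)[:, None]
print(h.shape)                                        # (2, 3): one pooled vector per example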

2.2 TensorFlow basics

a) variable_scope

1. When a variable is created with tf.Variable(), both tf.name_scope() and tf.variable_scope() add a prefix to the names of the Variable and of its ops.
2. When a variable is created with tf.get_variable(), tf.name_scope() does not add a prefix to the resulting Variable; only tf.variable_scope() does. The sketch below demonstrates this.
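
A small sketch of the difference under TensorFlow 1.x (the scope and variable names are arbitrary):

import tensorflow as tf

with tf.name_scope("ns"):
    a = tf.Variable([1.0], name="a")        # tf.Variable gets the prefix
    b = tf.get_variable("b", shape=[1])     # tf.get_variable does not

with tf.variable_scope("vs"):
    c = tf.Variable([1.0], name="c")        # prefixed
    d = tf.get_variable("d", shape=[1])     # prefixed as well

print(a.name)   # ns/a:0
print(b.name)   # b:0  <- name_scope is ignored by get_variable
print(c.name)   # vs/c:0
print(d.name)   # vs/d:0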


b) tf.nn.embedding_lookup

tf.nn.embedding_lookup selects the elements of a tensor at the given indices. In tf.nn.embedding_lookup(tensor, ids), tensor is the input tensor and ids are the indices to look up.
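
A toy example (TensorFlow 1.x):

import tensorflow as tf

embedding = tf.constant([[0.0, 0.0],
                         [1.0, 1.0],
                         [2.0, 2.0]])       # one row per vocabulary id
ids = tf.constant([2, 0, 1])
rows = tf.nn.embedding_lookup(embedding, ids)

with tf.Session() as sess:
    print(sess.run(rows))                   # [[2. 2.] [0. 0.] [1. 1.]]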


c) tf.device

Devices are referred to by name: the first GPU is /gpu:0, the second is /gpu:1, and so on.
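
For example, to pin ops to the first GPU (a sketch; allow_soft_placement lets TensorFlow fall back to the CPU when no GPU is present):

import tensorflow as tf

with tf.device("/gpu:0"):       # place these ops on the first GPU
    a = tf.constant([1.0, 2.0])
    b = a * 2.0

with tf.Session(config=tf.ConfigProto(allow_soft_placement=True)) as sess:
    print(sess.run(b))          # [2. 4.]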


d) Casting a variable to int/float

See https://www.tensorflow.org/api_docs/python/tf/to_float

tf.to_int32(variable)
tf.to_int64(variable)
tf.to_float(variable)


e) tf.nn.sparse_softmax_cross_entropy_with_logits

Error: "tensorflow: Only call sparse_softmax_cross_entropy_with_logits with named arguments"

The fix is to switch from positional to keyword arguments, i.e. change

tf.nn.sparse_softmax_cross_entropy_with_logits(logits, train_labels_node)

to

tf.nn.sparse_softmax_cross_entropy_with_logits(logits=logits, labels=train_labels_node)

2.3 Code

Reposted from https://github.com/luchi007/RNN_Text_Classify

import tensorflow as tf
import numpy as np

class RNN_Model(object):

    def __init__(self, config, is_training=True):

        self.keep_prob = config.keep_prob
        self.batch_size = tf.Variable(0, dtype=tf.int32, trainable=False)

        num_step = config.num_step
        self.input_data = tf.placeholder(tf.int32, [None, num_step])
        self.target = tf.placeholder(tf.int64, [None])
        # mask_x is 1.0 at real tokens and 0.0 at padding, shaped [num_step, batch]
        self.mask_x = tf.placeholder(tf.float32, [num_step, None])

        class_num = config.class_num
        hidden_neural_size = config.hidden_neural_size
        vocabulary_size = config.vocabulary_size
        embed_dim = config.embed_dim
        hidden_layer_num = config.hidden_layer_num
        self.new_batch_size = tf.placeholder(tf.int32, shape=[], name="new_batch_size")
        self._batch_size_update = tf.assign(self.batch_size, self.new_batch_size)

        # build LSTM network
        lstm_cell = tf.nn.rnn_cell.BasicLSTMCell(hidden_neural_size, forget_bias=0.0, state_is_tuple=True)
        if self.keep_prob < 1:
            lstm_cell = tf.nn.rnn_cell.DropoutWrapper(
                lstm_cell, output_keep_prob=self.keep_prob
            )

        cell = tf.nn.rnn_cell.MultiRNNCell([lstm_cell] * hidden_layer_num, state_is_tuple=True)

        self._initial_state = cell.zero_state(tf.to_int32(self.batch_size), dtype=tf.float32)

        # embedding layer
        with tf.device("/gpu:0"), tf.name_scope("embedding_layer"):
            embedding = tf.get_variable("embedding", [vocabulary_size, embed_dim], dtype=tf.float32)
            inputs = tf.nn.embedding_lookup(embedding, self.input_data)

        if self.keep_prob < 1:
            inputs = tf.nn.dropout(inputs, self.keep_prob)

        out_put = []
        state = self._initial_state
        # unroll the LSTM over the sequence, reusing the weights after the first step
        with tf.variable_scope("LSTM_layer"):
            for time_step in range(num_step):
                if time_step > 0:
                    tf.get_variable_scope().reuse_variables()
                (cell_output, state) = cell(inputs[:, time_step, :], state)
                out_put.append(cell_output)

        # zero out the outputs at padded positions
        out_put = out_put * self.mask_x[:, :, None]

        with tf.name_scope("mean_pooling_layer"):
            # average only over the valid (unpadded) time steps
            out_put = tf.reduce_sum(out_put, 0) / (tf.reduce_sum(self.mask_x, 0)[:, None])

        with tf.name_scope("Softmax_layer_and_output"):
            softmax_w = tf.get_variable("softmax_w", [hidden_neural_size, class_num], dtype=tf.float32)
            softmax_b = tf.get_variable("softmax_b", [class_num], dtype=tf.float32)
            self.logits = tf.matmul(out_put, softmax_w) + softmax_b

        with tf.name_scope("loss"):
            self.loss = tf.nn.sparse_softmax_cross_entropy_with_logits(logits=self.logits + 1e-10, labels=self.target)
            self.cost = tf.reduce_mean(self.loss)

        with tf.name_scope("accuracy"):
            self.prediction = tf.argmax(self.logits, 1)
            correct_prediction = tf.equal(self.prediction, self.target)
            self.correct_num = tf.reduce_sum(tf.cast(correct_prediction, tf.float32))
            self.accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32), name="accuracy")

        if not is_training:
            return

        self.globle_step = tf.Variable(0, name="globle_step", trainable=False)
        self.lr = tf.Variable(0.0, trainable=False)

        tvars = tf.trainable_variables()
        # clip gradients by global norm before applying them
        grads, _ = tf.clip_by_global_norm(tf.gradients(self.cost, tvars),
                                          config.max_grad_norm)
        optimizer = tf.train.GradientDescentOptimizer(self.lr)
        self.train_op = optimizer.apply_gradients(zip(grads, tvars))

        self.new_lr = tf.placeholder(tf.float32, shape=[], name="new_learning_rate")
        self._lr_update = tf.assign(self.lr, self.new_lr)

    def assign_new_lr(self, session, lr_value):
        session.run(self._lr_update, feed_dict={self.new_lr: lr_value})

    def assign_new_batch_size(self, session, batch_size_value):
        session.run(self._batch_size_update, feed_dict={self.new_batch_size: batch_size_value})
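
For reference, here is a minimal sketch of how this class might be driven. The Config values and the random data are hypothetical, purely for illustration; they are not the settings from the original repository.

import numpy as np
import tensorflow as tf

class Config(object):                 # hypothetical values, for illustration only
    keep_prob = 0.5
    num_step = 30
    class_num = 2
    hidden_neural_size = 128
    vocabulary_size = 10000
    embed_dim = 128
    hidden_layer_num = 1
    max_grad_norm = 5

model = RNN_Model(Config())
with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    model.assign_new_lr(sess, 0.1)
    model.assign_new_batch_size(sess, 64)

    # one training step on random toy data
    x = np.random.randint(0, Config.vocabulary_size, size=(64, Config.num_step))
    y = np.random.randint(0, Config.class_num, size=64)
    mask = np.ones((Config.num_step, 64), dtype=np.float32)
    _, cost = sess.run([model.train_op, model.cost],
                       feed_dict={model.input_data: x,
                                  model.target: y,
                                  model.mask_x: mask})
    print("cost:", cost)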


2.4 Results

[Figure: classification results]

2.5 Limitations of this code

a) It only does binary classification.
b) Classification is based on the title only.
Tags: text classification