【tensorflow 学习】seq2seq模型代码解读
2017-08-22 21:28
513 查看
1. sequence-to-sequence模型
官方教程使用seq2seq模型实现了英语-法语的翻译系统。经典的sequence-to-sequence模型由两个RNN网络构成,一个被称为“encoder”,另一个则称为“decoder”,前者负责把序列编码成一个固定长度的向量,这个向量作为输入传给后者,输出可变长度的向量,它的基本网络结构如下,其中每一个小圆圈代表一个cell,比如GRUcell、LSTMcell、multi-layer-GRUcell、multi-layer-GRUcell等。尽管“encoder”或者“decoder”内部存在权值共享,但encoder和decoder之间一般具有不同的一套参数。
2. 注意力机制(attention mechanism)
考虑到encoder将输入seq编码成一个向量,基本的decoder与encoder的交互仅在decoder初始的输入上,这样对于decoder而言,只能看到源信息的一个总体概要,会限制encoder-decoder架构的性能。基于这个缺点进行了改进, 在翻译阶段, 准备生成每个新的词的时候, 注意力机制可以将注意力集中在输入的某个或某几个词上,重点关注这几个词, 使得翻译更精准。新的模型架构如下图所示。
3. seq2seq_model.py 解读
这个模型依照 http://arxiv.org/abs/1412.7449 实现。机器学习模型的定义过程,一般包括输入变量定义、输入信息的forward propagation和误差信息的backward propagation三个部分,这三个部分在这个程序文件中都得到了很好的体现,下面我们结合代码分别进行介绍。
先来看一下seq2seq函数:
def seq2seq_f(encoder_inputs, decoder_inputs, do_decode): return tf.contrib.legacy_seq2seq.embedding_attention_seq2seq( encoder_inputs,# tensor of input seq decoder_inputs,# tensor of decoder seq cell, #自定义的cell,可以是GRU/LSTM, 设置multilayer等 num_encoder_symbols=source_vocab_size, # 英语词典大小 40000 num_decoder_symbols=target_vocab_size, # 法语词典大小 40000 embedding_size=size, # embedding 维度 output_projection=output_projection, # 不设定的话输出维数可能很大(取决于词表大小),设定的话投影到一个低维向量 feed_previous=do_decode, # false: 训练 ;True: 测试 dtype=dtype)
3.1 输入变量的定义
# Feeds for inputs. self.encoder_inputs = [] self.decoder_inputs = [] self.target_weights = [] for i in xrange(buckets[-1][0]): # Last bucket is the biggest one. self.encoder_inputs.append(tf.placeholder(tf.int32, shape=[None], name="encoder{0}".format(i))) #encoder_inputs 这个列表对象中的每一个元素表示一个占位符,其名字分别为encoder0, encoder1,…,encoder39,encoder{i}的几何意义是编码器在时刻i的输入。 for i in xrange(buckets[-1][1] + 1): self.decoder_inputs.append(tf.placeholder(tf.int32, shape=[None], name="decoder{0}".format(i))) self.target_weights.append(tf.placeholder(dtype, shape=[None], name="weight{0}".format(i))) #target_weights 是一个与 decoder_outputs 大小一样的 0-1 矩阵。该矩阵将目标序列长度以外的其他位置填充为标量值 0。 # Our targets are decoder inputs shifted by one. targets = [self.decoder_inputs[i + 1] for i in xrange(len(self.decoder_inputs) - 1)] # 跟language model类似,targets变量是decoder inputs平移一个单位的结果,
3.2 输入信息的forward propagation
# 区别在于seq2seq_f函数的参数feed previous是True还是false if forward_only: # 测试阶段 self.outputs, self.losses = tf.contrib.legacy_seq2seq.model_with_buckets(#?? self.encoder_inputs, self.decoder_inputs, targets, self.target_weights, buckets, lambda x, y: seq2seq_f( x, y, True), softmax_loss_function=softmax_loss_function) # If we use output projection, we need to project outputs for # decoding. if output_projection is not None: for b in xrange(len(buckets)): self.outputs[b] = [ tf.matmul(output, output_projection[ 0]) + output_projection[1] for output in self.outputs[b] ] else:#训练阶段 self.outputs, self.losses = tf.contrib.legacy_seq2seq.model_with_buckets( self.encoder_inputs, self.decoder_inputs, targets, self.target_weights, buckets, lambda x, y: seq2seq_f(x, y, False), softmax_loss_function=softmax_loss_function)
从代码中可以看到,输入信息的forward popagation分成了两种情况,这是因为整个sequence to sequence模型在训练阶段和测试阶段信息的流向是不一样的,这一点可以从seq2seqf函数的do_decode参数值体现出来,而do_decoder取值对应的就是tf.nn.seq2seq.embedding_attention_seq2seq函数中的feed_previous参数,forward_only为True也即feed_previous参数为True时进行模型测试,为False时进行模型训练。
3.3 误差信息的backward propagation
params = tf.trainable_variables() if not forward_only:# 只有训练阶段才需要计算梯度和参数更新 self.gradient_norms = [] self.updates = [] opt = tf.train.GradientDescentOptimizer(self.learning_rate) # 用梯度下降法优化 for b in xrange(len(buckets)): gradients = tf.gradients(self.losses[b], params) #计算损失函数关于参数的梯度 clipped_gradients, norm = tf.clip_by_global_norm(gradients, max_gradient_norm)# clip gradients 防止梯度爆炸 self.gradient_norms.append(norm) self.updates.append(opt.apply_gradients( zip(clipped_gradients, params), global_step=self.global_step))#更新参数
这一段代码主要用于计算损失函数关于参数的梯度。因为只有训练阶段才需要计算梯度和参数更新,所以这里有个if判断语句。并且,由于当前定义除了length(buckets)个graph,故返回值self.updates是一个列表对象,尺寸为length(buckets),列表中第i个元素表示graph{i}的梯度更新操作。
训练 RNN 的一个重要步骤是梯度截断(gradient clipping)。这里,我们使用全局范数进行截断操作。最大值
max_gradient_norm通常设置为 5 或 1。
3.4 模型训练
# Input feed: encoder inputs, decoder inputs, target_weights, as # provided. input_feed = {} for l in xrange(encoder_size): input_feed[self.encoder_inputs[l].name] = encoder_inputs[l] for l in xrange(decoder_size): input_feed[self.decoder_inputs[l].name] = decoder_inputs[l] input_feed[self.target_weights[l].name] = target_weights[l] # Since our targets are decoder inputs shifted by one, we need one # more. last_target = self.decoder_inputs[decoder_size].name input_feed[last_target] = np.zeros([self.batch_size], dtype=np.int32) # Output feed: depends on whether we do a backward step or not. if not forward_only: output_feed = [self.updates[bucket_id], # Update Op that does SGD. self.gradient_norms[bucket_id], # Gradient norm. self.losses[bucket_id]] # Loss for this batch. else: output_feed = [self.losses[bucket_id]] # Loss for this batch. for l in xrange(decoder_size): # Output logits. output_feed.append(self.outputs[bucket_id][l]) outputs = session.run(output_feed, input_feed) if not forward_only: # Gradient norm, loss, no outputs. return outputs[1], outputs[2], None else: # No gradient norm, loss, outputs. return None, outputs[0], outputs[1:]
模型已经定义完成了,这里便开始进行模型训练了。上面的两个for循环用于为之前定义的输入占位符赋予具体的数值,这些具体的数值源自于get_batch函数的返回值。当session.run函数开始执行时,当前session会对第bucket_id个graph进行参数更新操作。
Reference:
1. https://www.tensorflow.org/tutorials/seq2seq
2. http://www.2cto.com/kf/201612/575911.html
3. http://www.jianshu.com/p/58ef2b990d3f
相关文章推荐
- tensorflow学习(3):解读mnist_experts例子,训练保存模型并tensorboard可视化
- Tensorflow学习教程------读取数据、建立网络、训练模型,小巧而完整的代码示例
- Tensorflow-slim 学习笔记(二)第一层目录代码解读
- Tensorflow-slim 学习笔记(二)第一层目录代码解读
- Tensorflow学习:ResNet代码(详细剖析)-待补充,非最终版本
- 深度学习之卷积神经网络CNN及tensorflow代码实现示例
- Keras 深度学习代码笔记——模型保存与加载
- TensorFlow 深度学习框架(7)-- 变量管理及训练模型的保存与加载
- 【deep learning学习笔记】注释yusugomori的RBM代码 --- cpp文件 -- 模型测试
- 深度学习利器:TensorFlow在智能终端中的应用——智能边缘计算,云端生成模型给移动端下载,然后用该模型进行预测
- [caffe]深度学习之图像分类模型VGG解读
- TensorFlow 深度学习框架(9)-- 经典卷积网络模型 : LeNet-5 模型 & Inception-v3 模型
- tensorflow学习笔记1(代码转自官网)
- Tensorflow学习教程------模型参数和网络结构保存且载入,输入一张手写数字图片判断是几
- tensorflow的一些代码分析(五) tensorflow模型保存和可视化
- TensorFlow 深度学习框架(6)-- mnist 数字识别及不同模型效果比较
- Tensorflow学习笔记-构建网络模型
- LeNet-5模型详解及其TensorFlow代码实现
- 【深度学习】【Caffe源代码解读4】笔记22 Caffe框架中I/O模块的代码初探
- 序列到序列的语言翻译模型代码(tensorflow)解析