Attention Is All You Need PyTorch Implementation: Source Code Walkthrough 02 - Model Training (1) - The Training Code
Today we continue our walkthrough of the source code of the PyTorch implementation of the famous paper Attention Is All You Need.
Since the project is large, the walkthrough is split across several installments.
The previous installment is here:
Attention Is All You Need PyTorch Implementation: Source Code Walkthrough 01 - Data Preprocessing and Vocabulary Construction - https://blog.csdn.net/weixin_42744102/article/details/87006081
First, the source code on GitHub: https://github.com/Eathoublu/attention-is-all-you-need-pytorch
Project structure:

- transformer
  - __init__.py
  - Beam.py
  - Constants.py
  - Layers.py
  - Models.py
  - Modules.py
  - Optim.py
  - SubLayers.py
  - Translator.py
Today is the second installment, on model training. I will cover training in two parts: this part walks through the overall training code, train.py; the next part covers how the model itself is built and structured, i.e. Models.py under the transformer directory.
Let's look at the source of train.py together with my commentary:
The commentary is given as numbered comments in the code. Please read comments 1 through 22 in order; none of it is difficult, and I hope everyone can follow it.
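For reference, an invocation of train.py might look like the line below. Every flag shown appears in the argparse setup that follows; the data filename is only illustrative and should be whatever path the preprocessing step from the previous post saved:

    python train.py -data multi30k.atok.low.pt -save_model trained -save_mode best -proj_share_weight -label_smoothing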
```python
''' This script handles the training process. '''

import argparse
import math
import time

from tqdm import tqdm
import torch
import torch.nn.functional as F
import torch.optim as optim
import torch.utils.data

import transformer.Constants as Constants
from dataset import TranslationDataset, paired_collate_fn
from transformer.Models import Transformer
from transformer.Optim import ScheduledOptim


# 1 - Start reading here: below we enter the main function of train.py. In it
#     you can see that -data (the absolute path of the training/validation
#     data cleaned in the previous post) is the only required argument.
def main():
    ''' Main function '''
    parser = argparse.ArgumentParser()

    parser.add_argument('-data', required=True)

    parser.add_argument('-epoch', type=int, default=10)
    parser.add_argument('-batch_size', type=int, default=64)

    #parser.add_argument('-d_word_vec', type=int, default=512)
    parser.add_argument('-d_model', type=int, default=512)
    parser.add_argument('-d_inner_hid', type=int, default=2048)
    parser.add_argument('-d_k', type=int, default=64)
    parser.add_argument('-d_v', type=int, default=64)

    parser.add_argument('-n_head', type=int, default=8)
    parser.add_argument('-n_layers', type=int, default=6)
    parser.add_argument('-n_warmup_steps', type=int, default=4000)

    parser.add_argument('-dropout', type=float, default=0.1)
    parser.add_argument('-embs_share_weight', action='store_true')
    parser.add_argument('-proj_share_weight', action='store_true')

    parser.add_argument('-log', default=None)
    parser.add_argument('-save_model', default=None)
    parser.add_argument('-save_mode', type=str, choices=['all', 'best'], default='best')

    parser.add_argument('-no_cuda', action='store_true')
    parser.add_argument('-label_smoothing', action='store_true')

    opt = parser.parse_args()
    opt.cuda = not opt.no_cuda
    opt.d_word_vec = opt.d_model

    # 2 - Load the datasets. prepare_dataloaders() returns the training and
    #     validation sets as torch DataLoader objects, which makes it easy to
    #     feed the model batch by batch.
    #========= Loading Dataset =========#
    data = torch.load(opt.data)
    opt.max_token_seq_len = data['settings'].max_token_seq_len

    training_data, validation_data = prepare_dataloaders(data, opt)

    opt.src_vocab_size = training_data.dataset.src_vocab_size
    opt.tgt_vocab_size = training_data.dataset.tgt_vocab_size

    # 3 - Handling of an optional flag: the source and target word embeddings
    #     can share weights. We can ignore this for now.
    #========= Preparing Model =========#
    if opt.embs_share_weight:
        assert training_data.dataset.src_word2idx == training_data.dataset.tgt_word2idx, \
            'The src/tgt word2idx table are different but asked to share word embedding.'
```
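The data object loaded above is simply the dictionary that the preprocessing script saved in the previous post. train.py only relies on a handful of its keys; here is a small, illustrative way to inspect them (the filename is hypothetical, the keys are exactly the ones the code reads):

```python
import torch

# Hypothetical filename; use whatever -data path you pass to train.py.
data = torch.load('multi30k.atok.low.pt')

print(data['settings'].max_token_seq_len)  # longest token sequence, fixed at preprocessing time
print(len(data['dict']['src']))            # source word2idx table, i.e. the source vocabulary size
print(len(data['dict']['tgt']))            # target word2idx table
print(data['train']['src'][0])             # one training sentence, already converted to word indices
```

Back in main(), the script then builds the model and the optimizer, and kicks off training: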
```python
    print(opt)

    device = torch.device('cuda' if opt.cuda else 'cpu')  # 4 - train on an NVIDIA GPU if available, otherwise on the CPU
    transformer = Transformer(   # 5 - build the Transformer model, defined in transformer/Models.py; its internals are covered in the next post
        opt.src_vocab_size,      # 6 - size of the source vocabulary
        opt.tgt_vocab_size,      # 7 - size of the target vocabulary
        opt.max_token_seq_len,   # 8 - length of the longest sentence
        tgt_emb_prj_weight_sharing=opt.proj_share_weight,
        emb_src_tgt_weight_sharing=opt.embs_share_weight,
        d_k=opt.d_k,
        d_v=opt.d_v,
        d_model=opt.d_model,
        d_word_vec=opt.d_word_vec,
        d_inner=opt.d_inner_hid,
        n_layers=opt.n_layers,
        n_head=opt.n_head,       # number of attention heads
        dropout=opt.dropout).to(device)  # 9 - .to(device) selects where the model runs: 'cpu' means CPU, 'cuda' means GPU

    optimizer = ScheduledOptim(  # 10 - define the optimizer
        optim.Adam(
            filter(lambda x: x.requires_grad, transformer.parameters()),
            betas=(0.9, 0.98), eps=1e-09),
        opt.d_model, opt.n_warmup_steps)

    # 11 - Call train() to start training, passing in the Transformer model,
    #      the data loaders, the optimizer, and so on. Let's look at train() next.
    train(transformer, training_data, validation_data, optimizer, device, opt)


# 12 - Now we enter the train function.
def train(model, training_data, validation_data, optimizer, device, opt):
    ''' Start training '''

    log_train_file = None
    log_valid_file = None

    if opt.log:  # 13 - if -log was passed on the command line, write log files
        log_train_file = opt.log + '.train.log'
        log_valid_file = opt.log + '.valid.log'

        print('[Info] Training performance will be written to file: {} and {}'.format(
            log_train_file, log_valid_file))

        with open(log_train_file, 'w') as log_tf, open(log_valid_file, 'w') as log_vf:
            log_tf.write('epoch,loss,ppl,accuracy\n')
            log_vf.write('epoch,loss,ppl,accuracy\n')

    valid_accus = []  # 14 - this list records the validation accuracy of every epoch
    for epoch_i in range(opt.epoch):  # 15 - this loop runs one epoch of training per iteration; in detail:
        print('[ Epoch', epoch_i, ']')  # 16 - print the epoch number

        start = time.time()
        train_loss, train_accu = train_epoch(  # 17 - train for one epoch via train_epoch(), defined further down in this file; returns the training loss and accuracy
            model, training_data, optimizer, device, smoothing=opt.label_smoothing)
        print('  - (Training)   ppl: {ppl: 8.5f}, accuracy: {accu:3.3f} %, '\
              'elapse: {elapse:3.3f} min'.format(
                  ppl=math.exp(min(train_loss, 100)), accu=100*train_accu,
                  elapse=(time.time()-start)/60))

        start = time.time()
        valid_loss, valid_accu = eval_epoch(model, validation_data, device)  # 18 - evaluate on the validation set; returns the validation loss and accuracy
        print('  - (Validation) ppl: {ppl: 8.5f}, accuracy: {accu:3.3f} %, '\
              'elapse: {elapse:3.3f} min'.format(
                  ppl=math.exp(min(valid_loss, 100)), accu=100*valid_accu,
                  elapse=(time.time()-start)/60))

        valid_accus += [valid_accu]  # 19 - append this epoch's validation accuracy to the list

        model_state_dict = model.state_dict()  # 20 - snapshot the model parameters
        checkpoint = {
            'model': model_state_dict,
            'settings': opt,
            'epoch': epoch_i}

        if opt.save_model:  # 21 - persist the model to disk, if -save_model was passed
            if opt.save_mode == 'all':
                model_name = opt.save_model + '_accu_{accu:3.3f}.chkpt'.format(accu=100*valid_accu)
                torch.save(checkpoint, model_name)
            elif opt.save_mode == 'best':
                model_name = opt.save_model + '.chkpt'
                if valid_accu >= max(valid_accus):
                    torch.save(checkpoint, model_name)
                    print('    - [Info] The checkpoint file has been updated.')

        # 22 - Write the log files. This completes the training logic of
        #      train.py; the functions below are helpers, all of which were
        #      already touched on above.
        if log_train_file and log_valid_file:
            with open(log_train_file, 'a') as log_tf, open(log_valid_file, 'a') as log_vf:
                log_tf.write('{epoch},{loss: 8.5f},{ppl: 8.5f},{accu:3.3f}\n'.format(
                    epoch=epoch_i, loss=train_loss,
                    ppl=math.exp(min(train_loss, 100)), accu=100*train_accu))
                log_vf.write('{epoch},{loss: 8.5f},{ppl: 8.5f},{accu:3.3f}\n'.format(
                    epoch=epoch_i, loss=valid_loss,
                    ppl=math.exp(min(valid_loss, 100)), accu=100*valid_accu))
```
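One detail worth pausing on is comment 10: the optimizer is not plain Adam but a ScheduledOptim wrapper (transformer/Optim.py), and the training loop calls its step_and_update_lr() rather than step(). This applies the learning-rate schedule from the paper: the rate increases linearly over the first n_warmup_steps steps, then decays proportionally to the inverse square root of the step number. A minimal standalone sketch of that schedule (just the formula, not the repository's class):

```python
def lr_at_step(step, d_model=512, n_warmup_steps=4000):
    # Schedule from "Attention Is All You Need":
    # lr = d_model**-0.5 * min(step**-0.5, step * n_warmup_steps**-1.5)
    return d_model ** -0.5 * min(step ** -0.5, step * n_warmup_steps ** -1.5)

# Linear warmup up to step 4000, then inverse-square-root decay:
for step in (1, 1000, 4000, 16000):
    print(step, lr_at_step(step))
```

The remaining functions in train.py are the helpers referenced above: building the data loaders, computing loss and accuracy, and the per-epoch training and evaluation loops: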
```python
def prepare_dataloaders(data, opt):
    # ========= Preparing DataLoader =========#
    train_loader = torch.utils.data.DataLoader(
        TranslationDataset(
            src_word2idx=data['dict']['src'],
            tgt_word2idx=data['dict']['tgt'],
            src_insts=data['train']['src'],
            tgt_insts=data['train']['tgt']),
        num_workers=2,
        batch_size=opt.batch_size,
        collate_fn=paired_collate_fn,
        shuffle=True)

    valid_loader = torch.utils.data.DataLoader(
        TranslationDataset(
            src_word2idx=data['dict']['src'],
            tgt_word2idx=data['dict']['tgt'],
            src_insts=data['valid']['src'],
            tgt_insts=data['valid']['tgt']),
        num_workers=2,
        batch_size=opt.batch_size,
        collate_fn=paired_collate_fn)
    return train_loader, valid_loader


def cal_performance(pred, gold, smoothing=False):
    ''' Apply label smoothing if needed '''

    loss = cal_loss(pred, gold, smoothing)

    pred = pred.max(1)[1]
    gold = gold.contiguous().view(-1)
    non_pad_mask = gold.ne(Constants.PAD)
    n_correct = pred.eq(gold)
    n_correct = n_correct.masked_select(non_pad_mask).sum().item()

    return loss, n_correct


def cal_loss(pred, gold, smoothing):
    ''' Calculate cross entropy loss, apply label smoothing if needed. '''

    gold = gold.contiguous().view(-1)

    if smoothing:
        eps = 0.1
        n_class = pred.size(1)

        one_hot = torch.zeros_like(pred).scatter(1, gold.view(-1, 1), 1)
        one_hot = one_hot * (1 - eps) + (1 - one_hot) * eps / (n_class - 1)
        log_prb = F.log_softmax(pred, dim=1)

        non_pad_mask = gold.ne(Constants.PAD)
        loss = -(one_hot * log_prb).sum(dim=1)
        loss = loss.masked_select(non_pad_mask).sum()  # average later
    else:
        loss = F.cross_entropy(pred, gold, ignore_index=Constants.PAD, reduction='sum')

    return loss


def train_epoch(model, training_data, optimizer, device, smoothing):
    ''' Epoch operation in training phase'''

    model.train()  # note: this only switches the model into training mode; it does not itself perform any training

    total_loss = 0
    n_word_total = 0
    n_word_correct = 0

    for batch in tqdm(
            training_data, mininterval=2,
            desc='  - (Training)   ', leave=False):

        # prepare data
        src_seq, src_pos, tgt_seq, tgt_pos = map(lambda x: x.to(device), batch)
        gold = tgt_seq[:, 1:]

        # forward
        optimizer.zero_grad()
        pred = model(src_seq, src_pos, tgt_seq, tgt_pos)

        # backward
        loss, n_correct = cal_performance(pred, gold, smoothing=smoothing)
        loss.backward()

        # update parameters
        optimizer.step_and_update_lr()

        # note keeping
        total_loss += loss.item()

        non_pad_mask = gold.ne(Constants.PAD)
        n_word = non_pad_mask.sum().item()
        n_word_total += n_word
        n_word_correct += n_correct

    loss_per_word = total_loss/n_word_total
    accuracy = n_word_correct/n_word_total
    return loss_per_word, accuracy


def eval_epoch(model, validation_data, device):
    ''' Epoch operation in evaluation phase '''

    model.eval()

    total_loss = 0
    n_word_total = 0
    n_word_correct = 0

    with torch.no_grad():
        for batch in tqdm(
                validation_data, mininterval=2,
                desc='  - (Validation) ', leave=False):

            # prepare data
            src_seq, src_pos, tgt_seq, tgt_pos = map(lambda x: x.to(device), batch)
            gold = tgt_seq[:, 1:]

            # forward
            pred = model(src_seq, src_pos, tgt_seq, tgt_pos)
            loss, n_correct = cal_performance(pred, gold, smoothing=False)

            # note keeping
            total_loss += loss.item()

            non_pad_mask = gold.ne(Constants.PAD)
            n_word = non_pad_mask.sum().item()
            n_word_total += n_word
            n_word_correct += n_correct

    loss_per_word = total_loss/n_word_total
    accuracy = n_word_correct/n_word_total
    return loss_per_word, accuracy


if __name__ == '__main__':
    main()
```
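Two of these helpers deserve a closer look. First, gold = tgt_seq[:, 1:] drops the first target token, so the prediction at each position is scored against the next token (teacher forcing). Second, the label smoothing in cal_loss with eps = 0.1: instead of a hard one-hot target, the gold class receives probability 1 - eps and the remaining eps is spread evenly over the other n_class - 1 classes. A toy, self-contained example of that computation:

```python
import torch
import torch.nn.functional as F

# One token, four classes, gold class 2, following cal_loss step by step.
eps, n_class = 0.1, 4
pred = torch.tensor([[2.0, 0.5, 1.0, -1.0]])  # unnormalized logits, shape (1, n_class)
gold = torch.tensor([2])

one_hot = torch.zeros_like(pred).scatter(1, gold.view(-1, 1), 1)
one_hot = one_hot * (1 - eps) + (1 - one_hot) * eps / (n_class - 1)
print(one_hot)  # tensor([[0.0333, 0.0333, 0.9000, 0.0333]])

loss = -(one_hot * F.log_softmax(pred, dim=1)).sum(dim=1)
print(loss)     # smoothed cross entropy for this one token
```

The paper reports that this smoothing hurts perplexity, since the model learns to be less certain, but improves accuracy and BLEU.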
My own understanding is limited, so omissions and errors are inevitable; please do point them out and I will correct them as quickly as I can. Thank you! Work email: 1012950361@qq.com
Stay tuned for the next installment: Attention Is All You Need PyTorch Implementation: Source Code Walkthrough 03 - Model Training (2) - Building the Transformer Model