您的位置:首页 > 其它

pytorch加载和保存模型

2017-10-12 16:16 441 查看

在模型完成训练后,我们需要将训练好的模型保存为一个文件供测试使用,或者因为一些原因我们需要继续之前的状态训练之前保存的模型,那么如何在PyTorch中保存和恢复模型呢?

方法一(推荐):

第一种方法也是官方推荐的方法,只保存和恢复模型中的参数。

保存    

torch.save(the_model.state_dict(), PATH)

恢复

the_model = TheModelClass(*args, **kwargs) the_model.load_state_dict(torch.load(PATH))

使用这种方法,我们需要自己导入模型的结构信息。

方法二:

使用这种方法,将会保存模型的参数和结构信息。

保存

torch.save(the_model, PATH)

恢复

the_model = torch.load(PATH)

一个相对完整的例子

saving

torch.save({ 'epoch': epoch + 1, 'arch': args.arch, 'state_dict': model.state_dict(), 'best_prec1': best_prec1, }, 'checkpoint.tar' )

loading

if args.resume: if os.path.isfile(args.resume): print("=> loading checkpoint '{}'".format(args.resume)) checkpoint = torch.load(args.resume) args.start_epoch = checkpoint['epoch'] best_prec1 = checkpoint['best_prec1'] model.load_state_dict(checkpoint['state_dict']) print("=> loaded checkpoint '{}' (epoch {})" .format(args.evaluate, checkpoint['epoch']))  

获取模型中某些层的参数

对于恢复的模型,如果我们想查看某些层的参数,可以:

# 定义一个网络 from collections import OrderedDict model = nn.Sequential(OrderedDict([ ('conv1', nn.Conv2d(1,20,5)), ('relu1', nn.ReLU()), ('conv2', nn.Conv2d(20,64,5)), ('relu2', nn.ReLU()) ])) # 打印网络的结构 print(model)   OUT: Sequential ( (conv1): Conv2d(1, 20, kernel_size=(5, 5), stride=(1, 1)) (relu1): ReLU () (conv2): Conv2d(20, 64, kernel_size=(5, 5), stride=(1, 1)) (relu2): ReLU () )   如果我们想获取conv1的weight和bias:   params=model.state_dict() for k,v in params.items(): print(k) #打印网络中的变量名 print(params['conv1.weight']) #打印conv1的weight print(params['conv1.bias']) #打印conv1的bias  


内容来自用户分享和网络整理,不保证内容的准确性,如有侵权内容,可联系管理员处理 点击这里给我发消息
标签: