pytorch finetune模型
2017-11-28 14:25
651 查看
pytorch finetune模型
文章主要讲述如何在pytorch上读取以往训练的模型参数,在模型的名字已经变更的情况下又如何读取模型的部分参数等。——–作者:jiangwenj02【转载请注明】
pytorch 模型的存储与读取
其中在模型的保存过程有存储模型和参数一起的也有单独存储模型参数的单独存储模型参数
存储时使用:torch.save(the_model.state_dict(), PATH)
读取时:
the_model = TheModelClass(*args, **kwargs) the_model.load_state_dict(torch.load(PATH))
存储模型与参数
存储:torch.save(the_model, PATH)
读取:
the_model = torch.load(PATH)
模型的参数
fine-tune的过程是读取原有模型的参数,但是由于模型的所要处理的数据集不同,最后的一层class的总数不同,所以需要修改模型的最后一层,这样模型读取的参数,和在大数据集上训练好下载的模型参数在形式上不一样。需要我们自己去写函数读取参数。pytorch模型参数的形式
模型的参数是以字典的形式存储的。model_dict = the_model.state_dict(), for k,v in model_dict.items(): print(k)
即可看到所有的键值
如果想修改模型的参数,给相应的键值赋值即可
model_dict[k] = new_value
最后更新模型的参数
the_model.load_state_dict(model_dict)
如果模型的key值和在大数据集上训练时的key值是一样的
我们可以通过下列算法进行读取模型model_dict = model.state_dict() pretrained_dict = torch.load(model_path) # 1. filter out unnecessary keys diff = {k: v for k, v in model_dict.items() if \ k in pretrained_dict and pretrained_dict[k].size() == v.size()} pretrained_dict = {k: v for k, v in pretrained_dict.items() if k in model_dict and model_dict[k].size() == v.size()} pretrained_dict.update(diff) # 2. overwrite entries in the existing state dict model_dict.update(pretrained_dict) # 3. load the new state dict model.load_state_dict(model_dict)
如果模型的key值和在大数据集上训练时的key值是不一样的,但是顺序是一样的
model_dict = model.state_dict() pretrained_dict = torch.load(model_path) keys = [] for k,v in pretrained_dict.items(): keys.append(k) i = 0 for k,v in model_dict.items(): if v.size() == pretrained_dict[keys[i]].size(): print(k, ',', keys[i]) model_dict[k]=pretrained_dict[keys[i]] i = i + 1 model.load_state_dict(model_dict)
如果模型的key值和在大数据集上训练时的key值是不一样的,但是顺序是也不一样的
自己找对应关系,一个key对应一个key的赋值相关文章推荐
- (原+译)pytorch中保存和载入模型
- pytorch 如何加载部分预训练模型
- SSD: Single Shot MultiBox Detector 模型fine-tune和网络架构
- pytorch中获取模型input/output shape
- 浅谈将Pytorch模型从CPU转换成GPU
- PyTorch中使用预训练的模型初始化网络的一部分参数
- pytorch学习笔记(十一):fine-tune 预训练的模型
- 在caffe上做FCN模型fine-tune的一些注意事项
- pytorch入门(3)pytorch-seq2seq模型
- [置顶] Caffe windows 下进行(微调)fine-tune 模型
- (原)torch模型转pytorch模型
- Caffe学习系列(13):对训练好的模型进行fine-tune
- SSD: Single Shot MultiBox Detector 模型fine-tune和网络架构
- Pytorch——把模型的所有参数的梯度清0
- 利用pytorch构建简单的CNN模型(二)
- PyTorch(7)——模型的训练和测试、保存和加载
- pytorch构建网络模型的4种方法
- 《Caffe windows 下进行(微调)fine-tune 模型》读书笔记
- pytorch学习-迁移模型
- keras入门 ---在预训练好网络模型上进行fine-tune