pytorch 模型部分参数的加载
2019-02-20 15:47
387 查看
如果对预训练模型的结构进行了一些改动,在训练的开始前希望加载未改动部分的参数,如将resnet18的第一层卷积层conv1的输入由3通道改为6通道的new_conv1,将分类层fc的1000类输出改为2类输出的new_fc,注意:要改一下名字与原来的不同。
导入模型
myNet=ResNet()
然后就加载模型的参数,参考pytorch 如何加载部分预训练模型
pretrained_dict=torch.load(model_weight)
model_dict=myNet.state_dict()
1. filter out unnecessary keys
pretrained_dict = {k: v for k, v in pretrained_dict.items() if k in model_dict}
2. overwrite entries in the existing state dict
model_dict.update(pretrained_dict)
myNet.load_state_dict(model_dict)
也可以通过pretrained model.state_dict()提取需要的模型参数。
作者:lxx516
来源:CSDN
原文:https://blog.csdn.net/LXX516/article/details/80124768
版权声明:本文为博主原创文章,转载请附上博文链接!
相关文章推荐
- pytorch 模型部分参数的加载
- tensorflow之inception_v3模型的部分加载及权重的部分恢复(23)---《深度学习》
- tensorflow模型参数保存和加载问题
- 解决pytorch中DataParallel后模型参数出现问题的方法
- tensorflow模型参数与结构的加载-----二
- 【深度学习】tensorflow加载VGG16的网络结构和模型参数
- ReportingService错误:配置参数 SharePointIntegrated 被设置为 True,但无法加载 Share Point 对象模型
- 【pytorch】模型的搭建保存加载
- TF Saver 保存/加载训练好模型(网络+参数)的那些事儿
- 解决tensorflow模型参数保存和加载的问题
- pytorch 模型的保存和加载
- pytorch 如何加载部分预训练模型
- [TensorFlow深度学习入门]实战八·简便方法实现TensorFlow模型参数保存与加载(pb方式)
- Tensorflow: 如何加载神经网络部分参数
- Keras模型的加载和保存、预训练、按层名匹配参数
- tensorflow: 保存和加载模型, 参数;以及使用预训练参数方法
- pytorch 模型的加载
- pytorch中的pre-train函数模型引用及修改(增减网络层,修改某层参数等)
- 4000 TensorFlow学习笔记(2)——保存和加载训练模型参数
- Python机器学习----第4部分 模型评估和参数调优