您的位置:首页 > 其它

pytorch 如何加载部分预训练模型

2017-03-19 14:55 661 查看
pretrained_dict =...

model_dict = model.state_dict()

# 1. filter out unnecessary keys

pretrained_dict = {k: v for k, vin pretrained_dict.items() if k inmodel_dict}

# 2. overwrite entries in the existing state dict

model_dict.update(pretrained_dict)

# 3. load the new state dict

model.load_state_dict(model_dict)
内容来自用户分享和网络整理,不保证内容的准确性,如有侵权内容,可联系管理员处理 点击这里给我发消息
标签: