您的位置:首页 > 其它

PyTorch学习系列(十四)——保存训练好的模型

2017-06-02 08:48 190 查看
PyTorch提供了两种保存训练好的模型的方法。

第一种是只保存模型参数,这也是推荐的方法:

#保存
torch.save(the_model.state_dict(), PATH)
#读取
the_model = TheModelClass(*args, **kwargs)
the_model.load_state_dict(torch.load(PATH))


第二种方法保存整个模型:

#保存
torch.save(the_model, PATH)
#读取
the_model = torch.load(PATH)


参考

[1] http://pytorch.org/docs/notes/serialization.html
内容来自用户分享和网络整理,不保证内容的准确性,如有侵权内容,可联系管理员处理 点击这里给我发消息
标签: