解决pytorch 保存模型遇到的问题
2021-03-05 04:06
190 查看
今天用pytorch保存模型时遇到bug
Can't pickle <class 'torch._C._VariableFunctions'>
在google上查找原因,发现是保存时保存了整个模型的原因,而模型中有一些自定义的参数
将 torch.save(model,save_path)
改为 torch.save(model.state_dict(),save_path)
然后载入模型也做相应的更改就好了
补充:pytorch训练模型的一些坑
1. 图像读取
opencv的python和c++读取的图像结果不一致,是因为python和c++采用的opencv版本不一样,从而使用的解码库不同,导致读取的结果不同。
2. 图像变换
PIL和pytorch的图像resize操作,与opencv的resize结果不一样,这样会导致训练采用PIL,预测时采用opencv,结果差别很大,尤其是在检测和分割任务中比较明显。
3. 数值计算
pytorch的torch.exp与c++的exp计算,10e-6的数值时候会有10e-3的误差,对于高精度计算需要特别注意,比如
两个输入5.601597, 5.601601, 经过exp计算后变成270.85862343143174, 270.85970686809225
以上为个人经验,希望能给大家一个参考如有错误或未考虑完全的地方,望不吝赐教。
您可能感兴趣的文章:相关文章推荐
- 解决Pytorch 加载训练好的模型 遇到的error问题
- jQuery遇到的问题之刷新后不保存(待解决)
- Android歌词保存到本地及读取所遇到的字符乱码问题及解决
- 解决tensorflow模型参数保存和加载的问题
- c++调用pytorch的模型遇到的问题
- tensorflow模型运行遇到的问题以及解决办法:NotFoundError: Key Variable_10 not found in checkpoint
- 在SharePoint中服务器端使用Word编程模型转换PDF遇到的问题以及解决方法
- 求助 3ds max模型导入vrp的问题 求大神指教 等 这个问题大家有遇到过吗 应该怎么解决啊
- [Nebula2]使用3dmax7,nmaxtoolbox导出模型到nebula2遇到的问题及解决方法
- python第一个爬虫小程序以及遇到问题解决(中文乱码)+批量爬取网页并保存至本地
- PyTorch 解决Dataset和Dataloader遇到的问题
- 亚马逊和脸书发布 TorchServe 解决Pytorch 模型的部署问题
- 使用MySQL保存中文数据时,经常会遇到乱码问题的解决思路
- Matlab保存图像过程中遇到的问题和一些解决办法
- TensorFlow学习笔记10——TensorFlow保存和调用模型遇到的问题
- 解决学习tensorflow的LSTM模型中遇到一个版本不兼容问题
- 关于模型数组进行本地保存中遇到的一些问题
- windows系统conda高效安装pytorch以及解决遇到的下载问题
- tensorflow模型运行遇到的问题以及解决办法:NotFoundError: Key Variable_10 not found in checkpoint
- pytorch中关于GAN代码遇到的问题解决