您的位置:首页 > 其它

解决Mask RCNN的pytorch版本训练时候RuntimeError: Error(s) in loading state_dict for GeneralizedRCNN:

2020-06-06 07:21 12 查看

解决训练时候的RuntimeError: Error(s) in loading state_dict for GeneralizedRCNN:

- 此处错误是在Linux环境下maskrcnn的pytorch版本中遇到的

在训练完成mask rcnn模型后进行图像识别时候出现错误为

Traceback (most recent call last):
File "/home/wangfan/maskrcnn-benchmark/demo/person.py", line 74, in <module>
confidence_threshold=0.7,
File "/home/wangfan/maskrcnn-benchmark/demo/fcpredictorbox.py", line 73, in __init__
_ = checkpointer.load(cfg.MODEL.WEIGHT)
File "/home/wangfan/maskrcnn-benchmark/maskrcnn_benchmark/utils/checkpoint.py", line 62, in load
self._load_model(checkpoint)
File "/home/wangfan/maskrcnn-benchmark/maskrcnn_benchmark/utils/checkpoint.py", line 98, in _load_model
load_state_dict(self.model, checkpoint.pop("model"))
File "/home/wangfan/maskrcnn-benchmark/maskrcnn_benchmark/utils/model_serialization.py", line 80, in load_state_dict
model.load_state_dict(model_state_dict)
File "/home/wangfan/anaconda3/envs/maskrcnn/lib/python3.7/site-packages/torch/nn/modules/module.py", line 830, in load_state_dict
self.__class__.__name__, "\n\t".join(error_msgs)))
RuntimeError: Error(s) in loading state_dict for GeneralizedRCNN:
size mismatch for roi_heads.box.predictor.cls_score.weight: copying a param with shape torch.Size([4, 1024]) from checkpoint, the shape in current model is torch.Size([6, 1024]).
size mismatch for roi_heads.box.predictor.cls_score.bias: copying a param with shape torch.Size([4]) from checkpoint, the shape in current model is torch.Size([6]).
size mismatch for roi_heads.box.predictor.bbox_pred.weight: copying a param with shape torch.Size([16, 1024]) from checkpoint, the shape in current model is torch.Size([24, 1024]).
size mismatch for roi_heads.box.predictor.bbox_pred.bias: copying a param with shape torch.Size([16]) from checkpoint, the shape in current model is torch.Size([24]).
size mismatch for roi_heads.mask.predictor.mask_fcn_logits.weight: copying a param with shape torch.Size([4, 256, 1, 1]) from checkpoint, the shape in current model is torch.Size([6, 256, 1, 1]).
size mismatch for roi_heads.mask.predictor.mask_fcn_logits.bias: copying a param with shape torch.Size([4]) from checkpoint, the shape in current model is torch.Size([6]).

截图显示如下

表示训练的模型分类数量和设置中需要分类的模型数量不一致

maskrcnnbenchmark/maskrcnn_benchmark/config
路径下找到
defaults.py
文件
因此需要在defaults.py文件中将

_C.MODEL.ROI_BOX_HEAD.NUM_CLASSES = 5+1

改为自己训练模型设定的分类数量,

之前为5+1类别所以设定为,本文训练区分三类,加上背景一类为:3+1,

_C.MODEL.ROI_BOX_HEAD.NUM_CLASSES = 3+1

然后再运行图像测试程序
正确
如图所示

内容来自用户分享和网络整理,不保证内容的准确性,如有侵权内容,可联系管理员处理 点击这里给我发消息
标签: 
相关文章推荐