Faster RCNN原理及Pytorch代码解读——RPN(二):RPN网络结构
2020-07-01 17:37
34 查看
上一篇介绍了Anchor,这篇开始介绍RPN。整个RPN模块其实总的来说可以分成两部分来讲,一个是RPN的训练,这个很容易理解,网络肯定是需要训练优化的;另一个是生成候选区域,这个是RPN的主要功能。现在让我们整体看一下RPN的网络组成部分,后面再详解前面说的两个部分。
RPN网络结构
RPN的网络结构很简单,如下图所示
整个网络一共有两个分支,左边是分类分支,右边是回归分支。具体实现时, 在feature map上首先用3×3的卷积进行更深的特征提取, 然后利用1×1的卷积分别实现分类网络和回归网络。
分类分支
在物体检测中, 通常我们将有物体的位置称为前景, 没有物体的位置称为背景。 在分类网络分支中, 首先使用1×1卷积输出18×37×50的特征, 由于每个点默认有9个Anchors, 并且每个Anchor只预测其属于前景还是背景, 因此通道数为18。 随后利用torch.view()函数将特征映射到2×333×75, 这样第一维仅仅是一个Anchor的前景背景得分, 并送到Softmax函数中进行概率计算, 得到的特征再变换到18×37×50的维度,最终输出的是每个Anchor属于前景与背景的概率。
回归分支
回归分支中, 利用1×1卷积输出36×37×50的特征, 第一维的36包含9个Anchors的预测, 每一个Anchor有4个数据, 分别代表了每一个Anchor的中心点横纵坐标及宽高这4个量相对于真值的偏移量。
代码
源代码文件见lib/model/rpn/rpn.py。
class _RPN(nn.Module): """ region proposal network """ def __init__(self, din): super(_RPN, self).__init__() self.din = din # 特征图的深度, e.g., 512 self.anchor_scales = cfg.ANCHOR_SCALES # 锚框的尺度变化量,默认是[8, 16 ,32] self.anchor_ratios = cfg.ANCHOR_RATIOS # 锚框的宽高变化量,默认是[0.5, 1 ,2] self.feat_stride = cfg.FEAT_STRIDE[0] # 特征图的下采样倍数,默认是16 # 定义conv层处理输入特征映射 self.RPN_Conv = nn.Conv2d(self.din, 512, 3, 1, 1, bias=True) # 定义前景/背景分类层 self.nc_score_out = len(self.anchor_scales) * len(self.anchor_ratios) * 2 # 2(bg/fg) * 9 (anchors) self.RPN_cls_score = nn.Conv2d(512, self.nc_score_out, 1, 1, 0) # 定义锚框偏移量预测层 self.nc_bbox_out = len(self.anchor_scales) * len(self.anchor_ratios) * 4 # 4(coords) * 9 (anchors) self.RPN_bbox_pred = nn.Conv2d(512, self.nc_bbox_out, 1, 1, 0) # 定义区域生成模块 self.RPN_proposal = _ProposalLayer(self.feat_stride, self.anchor_scales, self.anchor_ratios) # 定义生成RPN训练标签模块,仅在训练时使用 self.RPN_anchor_target = _AnchorTargetLayer(self.feat_stride, self.anchor_scales, self.anchor_ratios) # RPN分类损失以及回归损失,仅在训练时计算 self.rpn_loss_cls = 0 self.rpn_loss_box = 0 @staticmethod def reshape(x, d): # 用于修改张量形状 input_shape = x.size() x = x.view( input_shape[0], int(d), int(float(input_shape[1] * input_shape[2]) / float(d)), input_shape[3] ) return x def forward(self, base_feat, im_info, gt_boxes, num_boxes): # 输入数据的第一维是batch数 batch_size = base_feat.size(0) # 首先利用3×3卷积进一步融合特征 rpn_conv1 = F.relu(self.RPN_Conv(base_feat), inplace=True) # 利用1×1卷积得到分类网络,每个点代表anchor的前景背景得分 rpn_cls_score = self.RPN_cls_score(rpn_conv1) # 利用reshape与softmax得到anchor的前景背景概率 rpn_cls_score_reshape = self.reshape(rpn_cls_score, 2) rpn_cls_prob_reshape = F.softmax(rpn_cls_score_reshape, 1) rpn_cls_prob = self.reshape(rpn_cls_prob_reshape, self.nc_score_out) # 利用1×1卷积得到回归网络,每一个点代表anchor的偏移 rpn_bbox_pred = self.RPN_bbox_pred(rpn_conv1) # 区域生成模块 cfg_key = 'TRAIN' if self.training else 'TEST' rois = self.RPN_proposal((rpn_cls_prob.data, rpn_bbox_pred.data, im_info, cfg_key)) self.rpn_loss_cls = 0 self.rpn_loss_box = 0 # 生成RPN训练时的标签和计算RPN的loss if self.training: assert gt_boxes is not None # 生成训练标签 rpn_data = self.RPN_anchor_target((rpn_cls_score.data, gt_boxes, im_info, num_boxes)) # 计算分类损失 rpn_cls_score = rpn_cls_score_reshape.permute(0, 2, 3, 1).contiguous().view(batch_size, -1, 2) rpn_label = rpn_data[0].view(batch_size, -1) rpn_keep = rpn_label.view(-1).ne(-1).nonzero().view(-1) rpn_cls_score = torch.index_select(rpn_cls_score.view(-1,2), 0, rpn_keep) rpn_label = torch.index_select(rpn_label.view(-1), 0, rpn_keep.data) rpn_label = rpn_label.long() # 先对scores进行筛选得到256个样本的得分,随后进行交叉熵求解 self.rpn_loss_cls = F.cross_entropy(rpn_cls_score, rpn_label) fg_cnt = torch.sum(rpn_label.data.ne(0)) rpn_bbox_targets, rpn_bbox_inside_weights, rpn_bbox_outside_weights = rpn_data[1:] # 利用smoothl1损失函数进行loss计算 self.rpn_loss_box = _smooth_l1_loss(rpn_bbox_pred, rpn_bbox_targets, rpn_bbox_inside_weights, rpn_bbox_outside_weights, sigma=3, dim=[1,2,3]) return rois, self.rpn_loss_cls, self.rpn_loss_box
相关文章推荐
- Faster RCNN原理及Pytorch代码解读——RPN(六):进一步筛选得到最终候选框
- Faster RCNN原理及Pytorch代码解读——RPN(五):生成候选区域
- Faster RCNN原理及Pytorch代码解读——RPN(三):RPN训练标签的生成
- Faster RCNN原理及Pytorch代码解读——RPN(四):损失函数
- Faster RCNN原理及Pytorch代码解读——RoI Polling
- Faster RCNN原理及Pytorch代码解读——全连接RCNN模块
- SSD原理及Pytorch代码解读——网络架构(一):基础结构
- faster-rcnn 之 RPN网络的结构解析以及RPN代码详解
- faster-rcnn 之 RPN网络的结构解析以及RPN代码详解
- Faster RCNN 源码解读(3.1) -- RPN源码结构介绍
- SSD原理及Pytorch代码解读——网络架构(二):特征提取网络及总体计算过程
- faster-rcnn 之 RPN网络的结构解析
- FASTER-CNN中RPN原理
- SSD原理及Pytorch代码解读——标签生成与损失求解
- faster rcnn源码理解(二)之AnchorTargetLayer(网络中的rpn_data)
- faster-rcnn 之 RPN网络的结构解析
- 多目标检测与识别 YOLOV3 解读3 网络结构及实现(PyTorch)
- faster-rcnn 之 RPN网络的结构解析
- Faster RCNN 源码解读(1) -- 文件结构分析
- faster-rcnn 中的RPN网络的结构解析