SSD原理及Pytorch代码解读——标签生成与损失求解
前面已经生成了6个特征图上所有的PriorBox的位置和已经将特征图转化为相应的分类和边框位置的预测值。为了可以进行训练,我们还需要进行标签的生成和计算损失。
标签生成
这一步主要是按照一定的原则,对所有的PriorBox赋予正、负样本的标签,并确定对应的真实物体标签,以方便后续损失的计算。
我们已经求得了求得8732个PriorBox坐标及对应的类别、位置预测后,首先要做的就是为每一个PriorBox贴标签,筛选出符合条件的正样本与负样本,以便进行后续的损失计算。判断依据与Faster RCNN相同,都是通过预测与真值的IoU值来判断。
筛选过程遵循以下4个原则:
- 在判断正、负样本时,IoU阈值设置为0.5,即一个PriorBox与所有真实框的最大IoU小于0.5时,判断该框为负样本。
- 判断对应关系时,将PriorBox与其拥有最大IoU的真实框作为其位置标签。
- 与真实框有最大IoU的PriorBox,即使该IoU不是此PriorBox与所有真实框IoU中最大的IoU,也要将该Box对应到真实框上,这是为了保证真实框的。
- 在预测边框位置时,是预测相对于预选框的偏移量,因此在求得匹配关系后还需要进行偏移量计算,具体公式如下:
{tx=(x−xa)waty=(y−ya)hatw=logwwath=loghha\begin{cases} t_x=\frac{(x-x_a)}{w_a} \\ t_y=\frac{(y-y_a)}{h_a} \\ t_w = \log \frac{w}{w_a} \\ t_h = \log \frac{h}{h_a}\end{cases} ⎩⎪⎪⎪⎨⎪⎪⎪⎧tx=wa(x−xa)ty=ha(y−ya)tw=logwawth=loghah
源码
# 输入包括IoU阈值、真实边框位置、预选框、方差、真实边框类别 # 输出为每一个预选框的类别,保存在conf_t中,对应的真实边框位置,保存在loc_t中 def match(threshold, truths, priors, variances, labels, loc_t, conf_t, idx): """ 对每一个PriorBox与真实边框进行重叠匹配计算,对边界框进行编码。 返回匹配的索引和对应的类别标签和坐标偏移值. Args: threshold: (float) 匹配计算时的重叠阈值. truths: (tensor) 真实边框的坐标, Shape: [num_obj, 4]. priors: (tensor) Prior boxes, Shape: [num_priors,4]. variances: (tensor) 对应于每个prior坐标的方差, Shape: [num_priors, 4]. labels: (tensor) 真实边框的类别真值, Shape: [num_obj]. loc_t: (tensor) 代填充张量,用于记录prior对应的位置偏移量,shape:[batch, num_priors, 4]. conf_t: (tensor) 代填充张量,用于记录prior对应的类别真值,shape:[batch, num_priors, 4]. idx: (int) 当前batch里样本的编号 Return: The matched indices corresponding to 1)location and 2)confidence preds. """ # 注意这里truth是最大最小值形式的,而prior是中心点与长宽形式 # 求取真实框与预选框的IoU矩阵,每一行代表一个标签,每一列代表一个prior,shape:[num_obj, num_priors] overlaps = jaccard( truths, point_form(priors) ) # 正负样本筛选 # 对每一行求最大值,得到每个真实边框与所有prior最大IoU值和最大值索引 # 返回的第一个为最大值,第二个为最大值的位置,shape:[num_objects, 1] best_prior_overlap, best_prior_idx = overlaps.max(1, keepdim=True) # 对每一列求最大值,得到每个prior与所有真实边框最大IoU值和最大值索引 # 返回的第一个为最大值,第二个为最大值的位置,shape:[1,num_priors] best_truth_overlap, best_truth_idx = overlaps.max(0, keepdim=True) best_truth_idx.squeeze_(0) # shape:[num_priors,] best_truth_overlap.squeeze_(0) # shape:[num_priors,] best_prior_idx.squeeze_(1) # shape:[num_objects,] best_prior_overlap.squeeze_(1) # shape:[num_objects,] # 将每一个truth对应的最佳box的overlap设置为2 best_truth_overlap.index_fill_(0, best_prior_idx, 2) # ensure best prior # TODO refactor: index best_prior_idx with long tensor # ensure every gt matches with its prior of max overlap # 保证每一个truth对应的最佳box,该box要对应到这个truth上,即使不是最大iou for j in range(best_prior_idx.size(0)): best_truth_idx[best_prior_idx[j]] = j # 每一个prior对应的真实框的位置 matches = truths[best_truth_idx] # Shape: [num_priors,4] # 每一个prior对应的类别, 0为负样本, 大于为正样本 conf = labels[best_truth_idx] + 1 # Shape: [num_priors] # 如果一个PriorBox对应的最大IoU小于0.5,则视为负样本 conf[best_truth_overlap < threshold] = 0 # label as background # 进一步计算定位的偏移真值 loc = encode(matches, priors, variances) loc_t[idx] = loc # [num_priors,4] encoded offsets to learn conf_t[idx] = conf # [num_priors] top class label for each prior
损失求解
整个损失函数可以分成两个部分来讲解:回归损失和分类损失
回归损失
求解回归损失比较简单,因为前面已经完成真实边框和PriorBox的匹配,知道了正负样本及每一样本对应的真实边框。还有求解时只需要计算正样本的损失就足够了。SSD使用了smoothL1作为损失函数,具体公式如下:
smoothL1(x)={0.5x2,if |x| < 1∣x∣−0.5,otherwisesmooth_{L_1}(x)=\begin{cases} 0.5x^2, & \text {if |x| < 1} \\ |x|-0.5, & \text{otherwise} \end{cases}smoothL1(x)={0.5x2,∣x∣−0.5,if |x| < 1otherwise
分类损失
一般情况下一张图片中存在的物体数量是很少的,基本上很少会超过100,而像SSD这样的采用了8732个先验框,因此存在大量负样本,如果都拿去训练,会导致正样本和负样本严重失衡。因此SSD采用的是难样本的挖掘。这里的难样本是针对负样本而言的。
Faster RCNN通过限制正负样本的数量来保持正、负样本均衡,而在SSD中,则是保证正、负样本的比例来实现样本均衡。具体做法是在计算出所有负样本的损失后进行排序,选取损失较大的那一部分进行计算,舍弃剩下的负样本,数量为正样本的3倍。
在计算完所有边框的类别交叉熵损失后,难样本挖掘过程主要分为5步:
- 过滤掉正样本
- 将负样本的损失排序
- 计算正样本的数量
- 通过正样本数量来得到得到负样本的数量
- 最后根据损失大小得到留下的负样本索引
在得到筛选后的正、负样本后,即可进行类别的损失计算。SSD在此使用了交叉熵损失函数,并且正、负样本全部参与计算。
源码
源代码文件见layers/modules/multibox_loss.py。
class MultiBoxLoss(nn.Module): """SSD 权重损失函数 计算目标: 1)计算真实边框与PriorBoxes的IoU矩阵,将真实边框与PriorBoxes匹配起来(匹配阈值默认为0.5) 2)计算PriorBoxes和对应边框的偏移真值 3) 对难样本进行挖掘,过滤大量负样本(保持正负样本为1:3) 目标损失: L(x,c,l,g) = (Lconf(x, c) + αLloc(x,l,g)) / N 其中,分类使用交叉熵损失。回归是SmoothL1损失,按α加权,通过交叉值设为1 Args: c: class confidences, l: predicted boxes, g: ground truth boxes N: number of matched default boxes See: https://arxiv.org/pdf/1512.02325.pdf for more details. """ def __init__(self, num_classes, overlap_thresh, prior_for_matching, bkg_label, neg_mining, neg_pos, neg_overlap, encode_target, use_gpu=True): super(MultiBoxLoss, self).__init__() self.use_gpu = use_gpu # 是否使用GPU self.num_classes = num_classes # 类别数量 self.threshold = overlap_thresh # Iou阈值 self.background_label = bkg_label # 背景类别,为0 self.encode_target = encode_target # self.use_prior_for_matching = prior_for_matching self.do_neg_mining = neg_mining self.negpos_ratio = neg_pos self.neg_overlap = neg_overlap #import pdb #pdb.set_trace() self.variance = cfg['variance'] def forward(self, predictions, targets): """Multibox Loss Args: predictions (tuple): A tuple containing loc preds, conf preds, and prior boxes from SSD net. conf shape: torch.size(batch_size,num_priors,num_classes) loc shape: torch.size(batch_size,num_priors,4) priors shape: torch.size(num_priors,4) targets (tensor): Ground truth boxes and labels for a batch, shape: [batch_size,num_objs,5] (last idx is the label). """ # 网络预测值,loc_data shape: [batch, num_priors, 4] # conf_data shape: [batch, num_priors, num_classes], priors shape: [num_priors, 4] loc_data, conf_data, priors = predictions num = loc_data.size(0) # 批处理大小 priors = priors[:loc_data.size(1), :] num_priors = (priors.size(0)) # priorbox总数,数值为8732 num_classes = self.num_classes # 类别数量 # 1 首先匹配正负样本 loc_t = torch.Tensor(num, num_priors, 4) # 回归偏移真值 conf_t = torch.LongTensor(num, num_priors) # 分类真值,0为负样本, >0为正样本 for idx in range(num): truths = targets[idx][:, :-1].data labels = targets[idx][:, -1].data defaults = priors.data # 得到每一个prior对应的truth,放到loc_t与conf_t中,conf_t中是类别,loc_t中是偏移真值 match(self.threshold, truths, defaults, self.variance, labels, loc_t, conf_t, idx) if self.use_gpu: loc_t = loc_t.cuda() conf_t = conf_t.cuda() # 2 计算所有正样本的定位损失,负样本不需要定位损失 # 计算正样本的数量 pos = conf_t > 0 num_pos = pos.sum(dim=1, keepdim=True) #import pdb #pdb.set_trace() # 回归损失 (Smooth L1) # Shape: [batch,num_priors,4] # 将pos_idx扩展为[32, 8732, 4],正样本的索引 pos_idx = pos.unsqueeze(pos.dim()).expand_as(loc_data) # 正样本的定位预测值 loc_p = loc_data[pos_idx].view(-1, 4) # 正样本的定位真值 loc_t = loc_t[pos_idx].view(-1, 4) # 所有正样本的定位损失 loss_l = F.smooth_l1_loss(loc_p, loc_t, size_average=False) # 3 对于类别损失,进行难样本挖掘,控制比例为1:3 # Compute max conf across batch for hard negative mining # 所有prior(即batch内所有prior)的类别预测 batch_conf = conf_data.view(-1, self.num_classes) # 计算类别损失.每一个的log(sum(exp(21个的预测)))-对应的真正预测值 loss_c = log_sum_exp(batch_conf) - batch_conf.gather(1, conf_t.view(-1, 1)) # Hard Negative Mining loss_c = loss_c.view(pos.size()[0], pos.size()[1]) # 首先过滤掉正样本 loss_c[pos] = 0 loss_c = loss_c.view(num, -1) _, loss_idx = loss_c.sort(1, descending=True) # 从大到小 # idx_rank为排序后每个元素的排名 _, idx_rank = loss_idx.sort(1) num_pos = pos.long().sum(1, keepdim=True) # 每张图片中正样本数量之和,shape:[batch, 1] # 这个地方负样本的最大值应该是pos.size(1)-num_pos,才能保证负样本索引不会超出边界(虽然一般而言都是正样本数量很少) num_neg = torch.clamp(self.negpos_ratio*num_pos, max=pos.size(1)-1) # 得到每个batch中负样本的索引掩码矩阵,shape[batch, 8732] # 具体的计算过程可以参考这篇博客——https://blog.csdn.net/laizi_laizi/article/details/103482634?utm_medium=distribute.pc_relevant_t0.none-task-blog-BlogCommendFromMachineLearnPai2-1.nonecase&depth_1-utm_source=distribute.pc_relevant_t0.none-task-blog-BlogCommendFromMachineLearnPai2-1.nonecase neg = idx_rank < num_neg.expand_as(idx_rank) # 4 计算正负样本的类别损失 # 都扩展为[batch, num_priors, num_classes] pos_idx = pos.unsqueeze(2).expand_as(conf_data) neg_idx = neg.unsqueeze(2).expand_as(conf_data) # 把正负样本的预测值提出来,shape: [batch*(pos_num+neg_num),self.num_classes] conf_p = conf_data[(pos_idx+neg_idx).gt(0)].view(-1, self.num_classes) # 把正负样本的类别标签提出来,shape: [batch*(pos_num+neg_num)] targets_weighted = conf_t[(pos+neg).gt(0)] loss_c = F.cross_entropy(conf_p, targets_weighted, size_average=False) # Sum of losses: L(x,c,l,g) = (Lconf(x, c) + αLloc(x,l,g)) / N N = num_pos.data.sum() # batch内所有正样本数量 loss_l /= N.type('torch.cuda.FloatTensor') loss_c /= N.type('torch.cuda.FloatTensor') return loss_l, loss_c
- Conditional GAN(条件对抗生成网络)原理及代码解读
- pytorch目标检测ssd七__训练代码与loss组成解析
- erlang抽象码与basho的protobuf(四)代码生成原理之代码生成
- 解读ssd中训练代码中知识点
- 编译原理第五章 语法制导翻译技术和中间代码生成
- 基于python3生成标签云代码解析
- SpringMVC实现原理和代码解读
- pytorch目标检测ssd六__预测效果与预测过程代码详解
- 编译原理课程设计_C--编译器_语法分析&代码生成 - Justin
- caffe层代码解读:SSD目标检测之MultiBox
- TSN算法的PyTorch代码解读(训练部分)
- 【Unity编辑器】Unity基于模板生成代码的原理与应用
- 代码生成原理研究
- SSD和Textboxes 原理以区别,以及代码的详细解释(主要是写给我自己看的)
- 自动生成图像标签代码
- php 生成短网址原理及代码
- php生成html原理核心代码
- 实时字幕生成原理挖掘——论文解读DenseCap: Fully Convolutional Localization Networks for Dense Captioning
- FCN详解与pytorch简单实现(附详细代码解读)
- 编译原理(七)中间代码生成