您的位置:首页 > 其它

ssd.pytorch源码分析(三)— 非极大值抑制NMS

2019-05-15 19:20 399 查看

NMS源码
SSD论文链接

NMS介绍

吴恩达对于NMS(非极大值抑制)的介绍:

说白了,NMS的作用就是去掉目标检测任务重复的检测框。 例如,一个目标有多个选择框,现在要去掉多余的选择框。怎么做呢?循环执行步骤1和2, 直到只剩下一个框:

  • 1、选出置信度p_c最高的框;
  • 2、去掉和这个框IOU>0.7的框。

相关函数

一、torch.clamp( )

torch.clamp(input, min, max, out=None) → Tensor

将输入input张量每个元素夹紧到区间 [min,max],并返回结果到一个新张量。
操作定义如下:

| min, if x_i < min
y_i = | x_i, if min <= x_i <= max
| max, if x_i > max

参数:

  • input (Tensor) – 输入张量
  • min (Number) – 限制范围下限
  • max (Number) – 限制范围上限
  • out (Tensor, optional) – 输出张量

例子:

>>> a = torch.randn(4)
>>> a
1.3869
0.3912
-0.8634
-0.5468
[torch.FloatTensor of size 4]
>>> torch.clamp(a, min=-0.5, max=0.5)
0.5000
0.3912
-0.5000
-0.5000
[torch.FloatTensor of size 4]

二、torch.index_select()

torch.index_select(input, dim, index, out=None) → Tensor

沿着指定维度对输入进行切片。

参数:

  • input (Tensor) – 输入张量
  • dim (int) – 索引的轴
  • index (LongTensor) – 包含索引下标的一维张量
  • out (Tensor, optional) – 目标张量

例子:

>>> x = torch.randn(3, 4)
>>> x

1.2045  2.4084  0.4001  1.1372
0.5596  1.5677  0.6219 -0.7954
1.3635 -1.2313 -0.5414 -1.8478
[torch.FloatTensor of size 3x4]

>>> indices = torch.LongTensor([0, 2])
>>> torch.index_select(x, 0, indices)

1.2045  2.4084  0.4001  1.1372
1.3635 -1.2313 -0.5414 -1.8478
[torch.FloatTensor of size 2x4]

>>> torch.index_select(x, 1, indices)

1.2045  0.4001
0.5596  0.6219
1.3635 -0.5414
[torch.FloatTensor of size 3x2]

注意,index_select函数中的参数index表示了有哪些索引值是需要保留的。

三、 torch.numel()

torch.numel(input)->int

返回input 张量中的元素个数。

复现代码

以下为ssd.pytorch中NMS(实际上在任何anchor based的目标检测框架中都适用)。其中:

  • 为了减少计算量,作者仅选取置信度前top_k=200个框;
  • 代码中包含了IOU的计算。关于IOU计算推荐阅读这篇文章
def nms(boxes, scores, overlap=0.7, top_k=200):
"""
输入:
boxes: 存储一个图片的所有预测框。[num_positive,4].
scores:置信度。如果为多分类则需要将nms函数套在一个循环内。[num_positive].
overlap: nms抑制时iou的阈值.
top_k: 先选取置信度前top_k个框再进行nms.
返回:
nms后剩余预测框的索引.
"""

keep = scores.new(scores.size(0)).zero_().long()
# 保存留下来的box的索引 [num_positive]
# 函数new(): 构建一个有相同数据类型的tensor

#如果输入box为空则返回空Tensor
if boxes.numel() == 0:
return keep

x1 = boxes[:, 0]
y1 = boxes[:, 1]
x2 = boxes[:, 2]
y2 = boxes[:, 3]
area = torch.mul(x2 - x1, y2 - y1) #并行化计算所有框的面积
v, idx = scores.sort(0)  # 升序排序
idx = idx[-top_k:]  # 前top-k的索引,从小到大
xx1 = boxes.new()
yy1 = boxes.new()
xx2 = boxes.new()
yy2 = boxes.new()
w = boxes.new()
h = boxes.new()

count = 0
while idx.numel() > 0:
i = idx[-1]  # 目前最大score对应的索引
keep[count] = i #存储在keep中
count += 1
if idx.size(0) == 1: #跳出循环条件:box被筛选完了
break
idx = idx[:-1]  # 去掉最后一个

#剩下boxes的信息存储在xx,yy中
torch.index_select(x1, 0, idx, out=xx1)
torch.index_select(y1, 0, idx, out=yy1)
torch.index_select(x2, 0, idx, out=xx2)
torch.index_select(y2, 0, idx, out=yy2)

# 计算当前最大置信框与其他剩余框的交集,作者这段代码写的不好,容易误导
xx1 = torch.clamp(xx1, min=x1[i])  #max(x1,xx1)
yy1 = torch.clamp(yy1, min=y1[i])  #max(y1,yy1)
xx2 = torch.clamp(xx2, max=x2[i])  #min(x2,xx2)
yy2 = torch.clamp(yy2, max=y2[i])  #min(y2,yy2)
w.resize_as_(xx2)
h.resize_as_(yy2)
w = xx2 - xx1 #w=min(x2,xx2)−max(x1,xx1)
h = yy2 - yy1 #h=min(y2,yy2)−max(y1,yy1)
w = torch.clamp(w, min=0.0) #max(w,0)
h = torch.clamp(h, min=0.0) #max(h,0)
inter = w*h

#计算当前最大置信框与其他剩余框的IOU
# IoU = i / (area(a) + area(b) - i)
rem_areas = torch.index_select(area, 0, idx)  # 剩余的框的面积
union = rem_areas + area[i]- inter #并集
IoU = inter/union  # 计算iou

# 选出IoU <= overlap的boxes(注意le函数的使用)
idx = idx[IoU.le(overlap)]
return keep,          count
#[num_remain], num_remain
内容来自用户分享和网络整理,不保证内容的准确性,如有侵权内容,可联系管理员处理 点击这里给我发消息
标签: