ssd.pytorch源码分析(三)— 非极大值抑制NMS
2019-05-15 19:20
399 查看
NMS介绍
吴恩达对于NMS(非极大值抑制)的介绍:
说白了,NMS的作用就是去掉目标检测任务重复的检测框。 例如,一个目标有多个选择框,现在要去掉多余的选择框。怎么做呢?循环执行步骤1和2, 直到只剩下一个框:
- 1、选出置信度p_c最高的框;
- 2、去掉和这个框IOU>0.7的框。
相关函数
一、torch.clamp( )
torch.clamp(input, min, max, out=None) → Tensor
将输入input张量每个元素夹紧到区间 [min,max],并返回结果到一个新张量。
操作定义如下:
| min, if x_i < min y_i = | x_i, if min <= x_i <= max | max, if x_i > max
参数:
- input (Tensor) – 输入张量
- min (Number) – 限制范围下限
- max (Number) – 限制范围上限
- out (Tensor, optional) – 输出张量
例子:
>>> a = torch.randn(4) >>> a 1.3869 0.3912 -0.8634 -0.5468 [torch.FloatTensor of size 4] >>> torch.clamp(a, min=-0.5, max=0.5) 0.5000 0.3912 -0.5000 -0.5000 [torch.FloatTensor of size 4]
二、torch.index_select()
torch.index_select(input, dim, index, out=None) → Tensor
沿着指定维度对输入进行切片。
参数:
- input (Tensor) – 输入张量
- dim (int) – 索引的轴
- index (LongTensor) – 包含索引下标的一维张量
- out (Tensor, optional) – 目标张量
例子:
>>> x = torch.randn(3, 4) >>> x 1.2045 2.4084 0.4001 1.1372 0.5596 1.5677 0.6219 -0.7954 1.3635 -1.2313 -0.5414 -1.8478 [torch.FloatTensor of size 3x4] >>> indices = torch.LongTensor([0, 2]) >>> torch.index_select(x, 0, indices) 1.2045 2.4084 0.4001 1.1372 1.3635 -1.2313 -0.5414 -1.8478 [torch.FloatTensor of size 2x4] >>> torch.index_select(x, 1, indices) 1.2045 0.4001 0.5596 0.6219 1.3635 -0.5414 [torch.FloatTensor of size 3x2]
注意,index_select函数中的参数index表示了有哪些索引值是需要保留的。
三、 torch.numel()
torch.numel(input)->int
返回input 张量中的元素个数。
复现代码
以下为ssd.pytorch中NMS(实际上在任何anchor based的目标检测框架中都适用)。其中:
- 为了减少计算量,作者仅选取置信度前top_k=200个框;
- 代码中包含了IOU的计算。关于IOU计算推荐阅读这篇文章;
def nms(boxes, scores, overlap=0.7, top_k=200): """ 输入: boxes: 存储一个图片的所有预测框。[num_positive,4]. scores:置信度。如果为多分类则需要将nms函数套在一个循环内。[num_positive]. overlap: nms抑制时iou的阈值. top_k: 先选取置信度前top_k个框再进行nms. 返回: nms后剩余预测框的索引. """ keep = scores.new(scores.size(0)).zero_().long() # 保存留下来的box的索引 [num_positive] # 函数new(): 构建一个有相同数据类型的tensor #如果输入box为空则返回空Tensor if boxes.numel() == 0: return keep x1 = boxes[:, 0] y1 = boxes[:, 1] x2 = boxes[:, 2] y2 = boxes[:, 3] area = torch.mul(x2 - x1, y2 - y1) #并行化计算所有框的面积 v, idx = scores.sort(0) # 升序排序 idx = idx[-top_k:] # 前top-k的索引,从小到大 xx1 = boxes.new() yy1 = boxes.new() xx2 = boxes.new() yy2 = boxes.new() w = boxes.new() h = boxes.new() count = 0 while idx.numel() > 0: i = idx[-1] # 目前最大score对应的索引 keep[count] = i #存储在keep中 count += 1 if idx.size(0) == 1: #跳出循环条件:box被筛选完了 break idx = idx[:-1] # 去掉最后一个 #剩下boxes的信息存储在xx,yy中 torch.index_select(x1, 0, idx, out=xx1) torch.index_select(y1, 0, idx, out=yy1) torch.index_select(x2, 0, idx, out=xx2) torch.index_select(y2, 0, idx, out=yy2) # 计算当前最大置信框与其他剩余框的交集,作者这段代码写的不好,容易误导 xx1 = torch.clamp(xx1, min=x1[i]) #max(x1,xx1) yy1 = torch.clamp(yy1, min=y1[i]) #max(y1,yy1) xx2 = torch.clamp(xx2, max=x2[i]) #min(x2,xx2) yy2 = torch.clamp(yy2, max=y2[i]) #min(y2,yy2) w.resize_as_(xx2) h.resize_as_(yy2) w = xx2 - xx1 #w=min(x2,xx2)−max(x1,xx1) h = yy2 - yy1 #h=min(y2,yy2)−max(y1,yy1) w = torch.clamp(w, min=0.0) #max(w,0) h = torch.clamp(h, min=0.0) #max(h,0) inter = w*h #计算当前最大置信框与其他剩余框的IOU # IoU = i / (area(a) + area(b) - i) rem_areas = torch.index_select(area, 0, idx) # 剩余的框的面积 union = rem_areas + area[i]- inter #并集 IoU = inter/union # 计算iou # 选出IoU <= overlap的boxes(注意le函数的使用) idx = idx[IoU.le(overlap)] return keep, count #[num_remain], num_remain
相关文章推荐
- ssd.pytorch源码分析(五)—损失函数及Hard negative mining
- Pytorch SSD模型分析
- py-faster-rcnn源码解读(一)NMS (非极大值抑制)
- Faster RCNN 源码解读(2) -- NMS(非极大抑制)
- PyTorch--双向递归神经网络(B-RNN)概念,源码分析
- web.py源码分析: application
- 非极大值抑制(Non-maximum suppression, NMS)
- 从pytorch的transfer learning tutorial讲分类任务的数据读取(深入分析torchvision.datasets.ImageFolder源码)
- Ciclop开源3D扫描仪软件---Horus源码分析之src\horus\engine\calibration\laser_triangulation.py
- pytorch学习笔记(十五):pytorch 源码编译碰到的坑总结
- 非极大值抑制(Non-Maximum Suppression,NMS)
- pyalgotrade源码分析4--PyAlgoTrade统计指标
- openstack nova 源码分析5-2 -nova/virt/libvirt目录下的connection.py
- Pytorch源码安装(附加可能出现的问题解决)
- 非极大值抑制(NMS)
- 非极大值抑制(Non-Maximum-Suppression, NMS)
- Attention is all you need pytorch实现 源码解析02 - 模型的训练(1)- 模型的训练代码
- 巡风源码阅读与分析---nascan.py
- Tensorflow Object Detection API 源码分析之 model_main.py
- 非极大值抑制算法 NMS