yolov5 极大值抑制 nms 代码详解

NMS介绍

一张图片经过模型的前向推理之后,会输出非常多的检测结果,如下图Prediction所示:

以一组数据来说明,推理一张图片,模型输入大小为640*640,2个类别,推理的结果产生的预测框计算:

  • 20 * 20 * 3
  • 40 * 40 * 3
  • 80 * 80 * 3

总计:25200
参数解释:

  • 20、40、80 代表生成的不同预测网格尺度,yolov5会输出3种尺度的预测结果,分别大尺度20*20,中尺度40*40,小尺度80*80
  • 3 代表每一个网格有3个预测框bounding box

所以要从这25200个数据中筛选出最好的两个结果,因为检查两个类别。筛选的过程就是NMS。

NMS: 非极大值抑制(Non-Maximum Suppression),
功能:从大量的预测结果中筛选出得分最高的结果。
思路:NMS的主要思路是通过计算目标框之间的重叠度(即IOU,交并比)来剔除非最佳结果。
代码实现:
YOLOv5的NMS代码使用了一个循环来遍历所有的预测框,并进行判断和筛选。

  1. 通过设定置信度阈值和IOU阈值,将预测框中置信度低于阈值的框过滤掉,只保留置信度高的框。
  2. 对剩下的框按照置信度进行降序排序,确保置信度高的框排在前面。
  3. 从置信度最高的框开始,计算它和剩下所有预测框的IOU。剩下的预测框中IOU低于设定的IOU阈值保留预测框。
  4. 重复上述步骤,直到遍历完所有的预测框,并得到最终筛选出来的目标框。

获取参数

传入的 prediction 为预测框,形状为:[29, 27783, 7] ,代表着推理的图片一共29张,每一张图片预测结果有27783个预测框,7代表预测结果x,y,w,h,confidence, cls1, cls2。
首先获取参数并校验参数。conf_thres、iou_thres都是影响剩余预测框的参数,校验其参数都在0-1之间。然后获取batch_size、类别个数,最后判断预测结果中大于设定阈值的框,得到一个布尔张量。

def non_max_suppression(
        prediction,
        conf_thres=0.25,
        iou_thres=0.45,
        classes=None,
        agnostic=False,
        multi_label=False,
        labels=(),
        max_det=300,
        nm=0,  # number of masks
):
    """Non-Maximum Suppression (NMS) on inference results to reject overlapping detections

    Returns:
         list of detections, on (n,6) tensor per image [xyxy, conf, cls]
    """

    """
    Pdb) pp prediction.shape
    torch.Size([29, 27783, 7])
    """

    # Checks
    assert 0 <= conf_thres <= 1, f'Invalid Confidence threshold {conf_thres}, valid values are between 0.0 and 1.0'
    assert 0 <= iou_thres <= 1, f'Invalid IoU {iou_thres}, valid values are between 0.0 and 1.0'
    if isinstance(prediction, (list, tuple)):  # YOLOv5 model in validation model, output = (inference_out, loss_out)
        prediction = prediction[0]  # select only inference output

    """
    (Pdb) bs
    29
    (Pdb) nc
    2
    (Pdb) xc.shape
    torch.Size([29, 27783])
    """

    bs = prediction.shape[0]  # batch size
    nc = prediction.shape[2] - nm - 5  # number of classes
    xc = prediction[..., 4] > conf_thres  # candidates

设定参数上限

设定一些nms中参数上线,主要用于在超过范围时限制。

  • max_wh: 图像最大的宽高
  • max_nms: 进入筛选预测框的上线
  • time_limit: 时间限制
  • multi_label: 多标签标识
  • merge:merge-nms标识
max_wh = 7680  # (pixels) maximum box width and height
max_nms = 30000  # maximum number of boxes into torchvision.ops.nms()
time_limit = 0.5 + 0.05 * bs  # seconds to quit after
redundant = True  # require redundant detections
multi_label &= nc > 1  # multiple labels per box (adds 0.5ms/img)
merge = False  # use merge-NMS

过滤出大于阈值的预测框

构建结果张量output,每一行由6个参数构成,分别是x,y,w,h,confidence,cls。这个就是模型推理的最后结果。
循环遍历prediction,xi是从0开始的下标,x是27783个预测结果。xc[xi] 获取阈值筛选的正负结果,x[xc[xi]] 通过布尔数组获取其中为True的结果,也就是置信度大于阈值的结果。

t = time.time()
mi = 5 + nc  # mask start index
output = [torch.zeros((0, 6 + nm), device=prediction.device)] * bs
for xi, x in enumerate(prediction):  # image index, image inference
    # Apply constraints
    # x[((x[..., 2:4] < min_wh) | (x[..., 2:4] > max_wh)).any(1), 4] = 0  # width-height

    """
    (Pdb) xc.shape
    torch.Size([29, 27783])
    
    (Pdb) xc[0]
    tensor([False, False, False,  ...,  True,  True,  True], device='cuda:0')\
    
    (Pdb) xc[0].shape
    torch.Size([27783])
    
    (Pdb) x.shape
    torch.Size([27783, 7])
    """
    x = x[xc[xi]]  # confidence

判断结果中没有大于阈值的结果,则跳过下面的流程

# If none remain process next image
if not x.shape[0]:
    continue

置信度处理

置信度在模型推理的结果上还有进一步的处理,将置信度 * 类别得分才是真正的confidence。
x代表的是预测结果,每一行的数据分别是:x, y, w, h, conf, cls。
x[:, 5:] 获取所有类别得分,x[:, 4:5] 获取置信度得分。两者乘积才是最终的confidence。

# Compute conf
x[:, 5:] *= x[:, 4:5]  # conf = obj_conf * cls_conf

获取结果中的预测框位置信息,并将位置信息从xywh转换成x1y1x2y2的格式。

"""
(Pdb) pp box[0]
tensor([-5.80469, -2.92188, 19.59375, 25.15625], device='cuda:0', dtype=torch.float16)
"""
box = xywh2xyxy(x[:, :4])  # center_x, center_y, width, height) to (x1, y1, x2, y2)
mask = x[:, mi:]  # zero columns if no masks

准备NMS计算的结构

在类别大于1的情况下,multi_label 为真,首先判断上一步confidence的结果中大于阈值的数据,并获取保留结果的预测框坐标轴i和j。 x为二维向量,i 代表 大于阈值的预测框所在的行,j代表所在的列。然后将位置信息、confidence、类别信息连接成新的向量。其中:

  • box[i]:xywh
  • x[i, 5 + j, None]: 获得了所有大于阈值预测框的confidence
  • j[:, None]: 类别ID
# Detections matrix nx6 (xyxy, conf, cls)
if multi_label:
    """
    x[:, 5:mi] > conf_thres: 大于阈值的结果
    
    i:bbox的横坐标
    j: bbox的纵坐标
    
    
    box[i]:xywh
    x[i, 5 + j, None]: 获得了所有大于阈值的bbox结果
    j[:, None]: 类别ID
    
    """
    i, j = (x[:, 5:mi] > conf_thres).nonzero(as_tuple=False).T
    x = torch.cat((box[i], x[i, 5 + j, None], j[:, None].float(), mask[i]), 1)
    """
    (Pdb) pp x.shape
    torch.Size([17039, 6])
    """

else:  # best class only
    conf, j = x[:, 5:mi].max(1, keepdim=True)
    x = torch.cat((box, conf, j.float(), mask), 1)[conf.view(-1) > conf_thres]

# Filter by class
if classes is not None:
    x = x[(x[:, 5:6] == torch.tensor(classes, device=x.device)).any(1)]

准备boxes 和 scores

对置信度排序,获取从大到小的预测框的下标,并过滤超出范围的预测框。对多类别中不同类别的预测框做一个偏移操作,防止不同类别的预测框互相影响。最后得到预测框坐标 boxes 和 置信度 scores。

# Check shape
n = x.shape[0]  # number of boxes
if not n:  # no boxes
    continue

# 对置信度排序,并去掉超出范围的结果
x = x[x[:, 4].argsort(descending=True)[:max_nms]]  # sort by confidence and remove excess boxes

# Batched NMS
# 这行代码是多类别中应用NMS
# 多类别NMS(非极大值抑制)的处理策略是为了让每个类都能独立执行NMS,在所有的边框上添加一个偏移量。
# 偏移量仅取决于类的ID(也就是x[:, 5:6]),并且足够大,以便来自不同类的框不会重叠。

c = x[:, 5:6] * (0 if agnostic else max_wh)  # classes

# 获取boxes 和 scores
boxes, scores = x[:, :4] + c, x[:, 4]  # boxes (offset by class), scores
"""
(Pdb) pp boxes.shape
torch.Size([17039, 4])
(Pdb) pp scores.shape
torch.Size([17039])
"""

计算结果

nms 筛选。nms筛选使用的是torchvision的工具,传入预测框位置,置信度,iou阈值,得到一个一维张量。

# 遍历score,通过IOU 筛选bbox。返回最终的bbox的index
i = torchvision.ops.nms(boxes, scores, iou_thres)  # NMS
"""
(Pdb) i.shape
torch.Size([300])
"""

torchvision.ops.nms 可以用纯python的操作代替。如下代码就是完成nms筛选的过程。主要流程:

  1. 传入预测框
  2. 计算所有预测框的面积
  3. 对置信度从大到小排序
  4. 循环筛选
  • 保存预测框第一个值,也就是最大置信度的值
  • 计算该值和剩余所有预测框的IOU,保存小于iou_thresh的预测框,大于的丢弃
  • 循环以上操作,直到所有的预测框都筛选完

# @jit(nopython=True)
def py_cpu_nms(dets, scores, conf_thresh=0.1, iou_thresh=0.5):
    """Pure Python NMS baseline."""

    x1 = dets[:, 0]
    y1 = dets[:, 1]
    x2 = dets[:, 2]
    y2 = dets[:, 3]

    """
    计算出所有框的面积
    """
    areas = (x2 - x1 + 1) * (y2 - y1 + 1)

    # 对置信度从小到大的排序,返回排序之后的下标的倒排,也就是从大到小的下标。
    order = scores.argsort()[::-1]

    # 创建数组保存最终结果,bbox的下标
    keep = []
    while order.size > 0:

        # 选择当前最大置信度的下标,第一循环为最大的
        i = order[0]
        keep.append(i)

        # 计算当前bbox和剩下所有bbox的IOU
        xx1 = np.maximum(x1[i], x1[order[1:]])
        yy1 = np.maximum(y1[i], y1[order[1:]])
        xx2 = np.minimum(x2[i], x2[order[1:]])
        yy2 = np.minimum(y2[i], y2[order[1:]])

        w = np.maximum(0.0, xx2 - xx1 + 1)
        h = np.maximum(0.0, yy2 - yy1 + 1)
        inter = w * h
        ovr = inter / (areas[i] + areas[order[1:]] - inter)

        # 筛选出IOU小于某一个阈值的bbox,也就是说过滤掉所有IOU大于阈值的bbox
        inds = np.where(ovr <= iou_thresh)[0]

        # 重置order数组,丢弃所有和当前bbox的IOU大于阈值的bbox
        order = order[inds + 1]

    # 重复当前动作,每一次添加一个最大置信度的bbox到keep,直到所有的数组都遍历完成。

    return keep

返回最终结果

收尾工作中获取预测值,打印处理时间等

i = i[:max_det]  # limit detections
if merge and (1 < n < 3E3):  # Merge NMS (boxes merged using weighted mean)
    # update boxes as boxes(i,4) = weights(i,n) * boxes(n,4)
    iou = box_iou(boxes[i], boxes) > iou_thres  # iou matrix
    weights = iou * scores[None]  # box weights
    x[i, :4] = torch.mm(weights, x[:, :4]).float() / weights.sum(1, keepdim=True)  # merged boxes
    if redundant:
        i = i[iou.sum(1) > 1]  # require redundancy

# 获取某一个网格的最终结果
output[xi] = x[i]
"""
(Pdb) p output[xi].shape
torch.Size([300, 6])
"""

if mps:
    output[xi] = output[xi].to(device)
if (time.time() - t) > time_limit:
    LOGGER.warning(f'WARNING ⚠️ NMS time limit {time_limit:.3f}s exceeded')
    break  # time limit exceeded

最终的输出结果是:

(Pdb) len(output)
29
(Pdb) output[0][0]
tensor([5.56500e+02, 1.74750e+02, 6.34500e+02, 2.14250e+02, 7.48901e-02, 0.00000e+00], device='cuda:0')

输出分别代表:x1, y1, x2, y2, confidence, cls_index。

x1 y1 x2 y2 概率 类别
556 174 634 214 0.069 0

参考:yolov5 nms 源码理解_nc = prediction.shape[2] - 5 # number of classes-CSDN博客

附录完整注释代码

def non_max_suppression(
        prediction,
        conf_thres=0.25,
        iou_thres=0.45,
        classes=None,
        agnostic=False,
        multi_label=False,
        labels=(),
        max_det=300,
        nm=0,  # number of masks
):
    """Non-Maximum Suppression (NMS) on inference results to reject overlapping detections

    Returns:
         list of detections, on (n,6) tensor per image [xyxy, conf, cls]
    """

    # Checks
    assert 0 <= conf_thres <= 1, f'Invalid Confidence threshold {conf_thres}, valid values are between 0.0 and 1.0'
    assert 0 <= iou_thres <= 1, f'Invalid IoU {iou_thres}, valid values are between 0.0 and 1.0'
    if isinstance(prediction, (list, tuple)):  # YOLOv5 model in validation model, output = (inference_out, loss_out)
        prediction = prediction[0]  # select only inference output

    """
    Pdb) pp prediction.shape
    torch.Size([29, 27783, 7])
    
    """


    device = prediction.device
    mps = 'mps' in device.type  # Apple MPS
    if mps:  # MPS not fully supported yet, convert tensors to CPU before NMS
        prediction = prediction.cpu()

    """
    (Pdb) bs
    29
    (Pdb) nc
    2
    (Pdb) xc.shape
    torch.Size([29, 27783])
    """

    bs = prediction.shape[0]  # batch size
    nc = prediction.shape[2] - nm - 5  # number of classes
    xc = prediction[..., 4] > conf_thres  # candidates

    breakpoint()
    # Settings
    # min_wh = 2  # (pixels) minimum box width and height
    max_wh = 7680  # (pixels) maximum box width and height
    max_nms = 30000  # maximum number of boxes into torchvision.ops.nms()
    time_limit = 0.5 + 0.05 * bs  # seconds to quit after
    redundant = True  # require redundant detections
    multi_label &= nc > 1  # multiple labels per box (adds 0.5ms/img)
    merge = False  # use merge-NMS

    t = time.time()
    mi = 5 + nc  # mask start index
    output = [torch.zeros((0, 6 + nm), device=prediction.device)] * bs
    for xi, x in enumerate(prediction):  # image index, image inference
        # Apply constraints
        # x[((x[..., 2:4] < min_wh) | (x[..., 2:4] > max_wh)).any(1), 4] = 0  # width-height

        """
        (Pdb) xc.shape
        torch.Size([29, 27783])
        
        (Pdb) xc[0]
        tensor([False, False, False,  ...,  True,  True,  True], device='cuda:0')\
        
        (Pdb) xc[0].shape
        torch.Size([27783])
        
        (Pdb) x.shape
        torch.Size([27783, 7])
        """
        x = x[xc[xi]]  # confidence

        """
        (Pdb) pp x.shape
        torch.Size([17788, 7])
        """


        # Cat apriori labels if autolabelling
        if labels and len(labels[xi]):
            lb = labels[xi]
            v = torch.zeros((len(lb), nc + nm + 5), device=x.device)
            v[:, :4] = lb[:, 1:5]  # box
            v[:, 4] = 1.0  # conf
            v[range(len(lb)), lb[:, 0].long() + 5] = 1.0  # cls
            x = torch.cat((x, v), 0)

        # If none remain process next image
        if not x.shape[0]:
            continue

        """
        (Pdb) x[0]
        tensor([6.89062e+00, 1.11172e+01, 2.53906e+01, 2.80781e+01, 1.03188e-03, 1.96777e-01, 2.99805e-01], device='cuda:0', dtype=torch.float16)
        (Pdb) n
        > /home/lijinkui/Documents/code/yolov5/utils/general.py(922)non_max_suppression()
        -> box = xywh2xyxy(x[:, :4])  # center_x, center_y, width, height) to (x1, y1, x2, y2)
        (Pdb) x[0]
        tensor([6.89062e+00, 1.11172e+01, 2.53906e+01, 2.80781e+01, 1.03188e-03, 2.03013e-04, 3.09467e-04], device='cuda:0', dtype=torch.float16)
        """
        # Compute conf
        x[:, 5:] *= x[:, 4:5]  # conf = obj_conf * cls_conf

        # Box/Mask
        """
        (Pdb) pp box[0]
        tensor([-5.80469, -2.92188, 19.59375, 25.15625], device='cuda:0', dtype=torch.float16)
        """
        box = xywh2xyxy(x[:, :4])  # center_x, center_y, width, height) to (x1, y1, x2, y2)
        mask = x[:, mi:]  # zero columns if no masks

        # Detections matrix nx6 (xyxy, conf, cls)
        if multi_label:
            """
            x[:, 5:mi] > conf_thres: 大于阈值的结果
            
            i:bbox的横坐标
            j: bbox的纵坐标
            
            
            box[i]:xywh
            x[i, 5 + j, None]: 获得了所有大于阈值的bbox结果
            j[:, None]: 类别ID
            
            """
            i, j = (x[:, 5:mi] > conf_thres).nonzero(as_tuple=False).T
            x = torch.cat((box[i], x[i, 5 + j, None], j[:, None].float(), mask[i]), 1)
            """
            (Pdb) pp x.shape
            torch.Size([17039, 6])
            """

        else:  # best class only
            conf, j = x[:, 5:mi].max(1, keepdim=True)
            x = torch.cat((box, conf, j.float(), mask), 1)[conf.view(-1) > conf_thres]

        # Filter by class
        if classes is not None:
            x = x[(x[:, 5:6] == torch.tensor(classes, device=x.device)).any(1)]

        # Apply finite constraint
        # if not torch.isfinite(x).all():
        #     x = x[torch.isfinite(x).all(1)]

        # Check shape
        n = x.shape[0]  # number of boxes
        if not n:  # no boxes
            continue

        # 对置信度排序,并去掉超出范围的结果
        x = x[x[:, 4].argsort(descending=True)[:max_nms]]  # sort by confidence and remove excess boxes

        # Batched NMS
        c = x[:, 5:6] * (0 if agnostic else max_wh)  # classes

        # 获取boxes 和 scores
        boxes, scores = x[:, :4] + c, x[:, 4]  # boxes (offset by class), scores
        """
        (Pdb) pp boxes.shape
        torch.Size([17039, 4])
        (Pdb) pp scores.shape
        torch.Size([17039])
        """

        # 遍历score,通过IOU 筛选bbox。返回最终的bbox的index
        i = torchvision.ops.nms(boxes, scores, iou_thres)  # NMS
        """
        (Pdb) i.shape
        torch.Size([300])
        """

        i = i[:max_det]  # limit detections
        if merge and (1 < n < 3E3):  # Merge NMS (boxes merged using weighted mean)
            # update boxes as boxes(i,4) = weights(i,n) * boxes(n,4)
            iou = box_iou(boxes[i], boxes) > iou_thres  # iou matrix
            weights = iou * scores[None]  # box weights
            x[i, :4] = torch.mm(weights, x[:, :4]).float() / weights.sum(1, keepdim=True)  # merged boxes
            if redundant:
                i = i[iou.sum(1) > 1]  # require redundancy

        # 获取某一个网格的最终结果
        output[xi] = x[i]
        """
        (Pdb) p output[xi].shape
        torch.Size([300, 6])
        """

        if mps:
            output[xi] = output[xi].to(device)
        if (time.time() - t) > time_limit:
            LOGGER.warning(f'WARNING ⚠️ NMS time limit {time_limit:.3f}s exceeded')
            break  # time limit exceeded

    return output
posted @ 2024-07-17 09:59  金色旭光  阅读(556)  评论(0编辑  收藏  举报