NMS section: non_max_suppression()

I. The validation (Validate) part

1. NMS call site

# NMS
# targets: (image, class, x, y, w, h), scale normalized xywh to pixel coordinates
targets[:, 2:] *= torch.tensor((width, height, width, height), device=device)  # to pixels
lb = [targets[targets[:, 0] == i, 1:] for i in range(nb)] if save_hybrid else []  # for autolabelling
with dt[2]:
    # non-maximum suppression
    preds = non_max_suppression(preds,
                                conf_thres,
                                iou_thres,
                                labels=lb,
                                multi_label=True,
                                agnostic=single_cls,
                                max_det=max_det)
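
A rough sketch (my own illustration, not from the source; shapes follow the docstring of non_max_suppression below) of what the call above returns: the raw preds tensor of shape [batch, num_boxes, 5 + nc] becomes a Python list with one [n, 6] tensor per image.

# after the call above, preds is a list of length batch_size (assumed variable names)
for si, pred in enumerate(preds):                  # one tensor per image
    # pred: [n, 6] -> x1, y1, x2, y2, confidence, class index
    for *xyxy, conf, cls in pred.tolist():
        print(si, int(cls), round(conf, 3), xyxy)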

II. non_max_suppression()

def non_max_suppression(
        prediction,
        conf_thres=0.25,
        iou_thres=0.45,
        classes=None,
        agnostic=False,
        multi_label=False,
        labels=(),
        max_det=300,
        nm=0,  # number of masks
):
    """Non-Maximum Suppression (NMS) on inference results to reject overlapping detections

    Returns:
         list of detections, one (n,6) tensor per image [xyxy, conf, cls]
    """

    # Checks
    assert 0 <= conf_thres <= 1, f'Invalid Confidence threshold {conf_thres}, valid values are between 0.0 and 1.0'
    assert 0 <= iou_thres <= 1, f'Invalid IoU {iou_thres}, valid values are between 0.0 and 1.0'
    if isinstance(prediction, (list, tuple)):  # YOLOv5 model in validation mode, output = (inference_out, loss_out)
        prediction = prediction[0]  # select only inference output

    device = prediction.device
    mps = 'mps' in device.type  # Apple MPS
    if mps:  # MPS not fully supported yet, convert tensors to CPU before NMS
        prediction = prediction.cpu()
    bs = prediction.shape[0]  # batch size
    nc = prediction.shape[2] - nm - 5  # number of classes
    # filter candidates by objectness confidence threshold
    xc = prediction[..., 4] > conf_thres  # candidates

    # Settings
    # min_wh = 2  # (pixels) minimum box width and height
    max_wh = 7680  # (pixels) maximum box width and height
    max_nms = 30000  # maximum number of boxes into torchvision.ops.nms()
    time_limit = 0.5 + 0.05 * bs  # seconds to quit after
    redundant = True  # require redundant detections
    multi_label &= nc > 1  # multiple labels per box (adds 0.5ms/img)
    merge = False  # use merge-NMS

    t = time.time()
    mi = 5 + nc  # mask start index
    output = [torch.zeros((0, 6 + nm), device=prediction.device)] * bs
    for xi, x in enumerate(prediction):  # image index, image inference
        # Apply constraints
        # x[((x[..., 2:4] < min_wh) | (x[..., 2:4] > max_wh)).any(1), 4] = 0  # width-height
        # keep only boxes whose objectness exceeds conf_thres
        x = x[xc[xi]]  # confidence

        # Cat apriori labels if autolabelling
        if labels and len(labels[xi]):
            lb = labels[xi]
            v = torch.zeros((len(lb), nc + nm + 5), device=x.device)
            v[:, :4] = lb[:, 1:5]  # box
            v[:, 4] = 1.0  # conf
            v[range(len(lb)), lb[:, 0].long() + 5] = 1.0  # cls
            x = torch.cat((x, v), 0)

        # If none remain process next image
        if not x.shape[0]:
            continue

        # Compute conf
        # 80 class confidences = 80 class confidences * objectness confidence
        x[:, 5:] *= x[:, 4:5]  # conf = obj_conf * cls_conf

        # Box/Mask
        box = xywh2xyxy(x[:, :4])  # (center_x, center_y, width, height) to (x1, y1, x2, y2)
        mask = x[:, mi:]  # zero columns if no masks

        # Detections matrix nx6 (xyxy, conf, cls)
        if multi_label:
            # (row, column) indices where the class scores (e.g. a [246, 80] slice) exceed conf_thres
            i, j = (x[:, 5:mi] > conf_thres).nonzero(as_tuple=False).T
            # build a new x[N, 6+nm] from the boxes that pass conf_thres: box (xyxy), class confidence, column index (= class), masks
            x = torch.cat((box[i], x[i, 5 + j, None], j[:, None].float(), mask[i]), 1)
        else:  # best class only
            conf, j = x[:, 5:mi].max(1, keepdim=True)
            x = torch.cat((box, conf, j.float(), mask), 1)[conf.view(-1) > conf_thres]

        # Filter by class
        if classes is not None:
            x = x[(x[:, 5:6] == torch.tensor(classes, device=x.device)).any(1)]

        # Apply finite constraint
        # if not torch.isfinite(x).all():
        #     x = x[torch.isfinite(x).all(1)]

        # Check shape
        n = x.shape[0]  # number of boxes
        if not n:  # no boxes
            continue
        # sort x by confidence (descending) and keep at most the top max_nms (30000) boxes
        x = x[x[:, 4].argsort(descending=True)[:max_nms]]  # sort by confidence and remove excess boxes

        # Batched NMS
        # NMS filtering
        # c = max_wh (7680) * class index (0-79)
        c = x[:, 5:6] * (0 if agnostic else max_wh)  # classes
        # offset each box by its class so that boxes of different classes never suppress each other; overlapping boxes of the same class are still suppressed -> see Soft-NMS
        boxes, scores = x[:, :4] + c, x[:, 4]  # boxes (offset by class), scores
        # returns the indices of the kept boxes; the next line truncates them to max_det (300)
        i = torchvision.ops.nms(boxes, scores, iou_thres)  # NMS
        i = i[:max_det]  # limit detections
        if merge and (1 < n < 3E3):  # Merge NMS (boxes merged using weighted mean)
            # update boxes as boxes(i,4) = weights(i,n) * boxes(n,4)
            iou = box_iou(boxes[i], boxes) > iou_thres  # iou matrix
            weights = iou * scores[None]  # box weights
            x[i, :4] = torch.mm(weights, x[:, :4]).float() / weights.sum(1, keepdim=True)  # merged boxes
            if redundant:
                i = i[iou.sum(1) > 1]  # require redundancy

        output[xi] = x[i]
        if mps:
            output[xi] = output[xi].to(device)
        if (time.time() - t) > time_limit:
            LOGGER.warning(f'WARNING ⚠️ NMS time limit {time_limit:.3f}s exceeded')
            break  # time limit exceeded
    # output: list of per-image [N, 6] tensors, the 6 columns being xyxy, class confidence, predicted class
    return output
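
The helper xywh2xyxy called above (from YOLOv5's utils) is not reproduced in this post; a minimal sketch of what it does, converting center-format boxes to corner format:

import torch

def xywh2xyxy(x):
    # [center_x, center_y, w, h] -> [x1, y1, x2, y2]
    y = x.clone()
    y[..., 0] = x[..., 0] - x[..., 2] / 2  # top-left x
    y[..., 1] = x[..., 1] - x[..., 3] / 2  # top-left y
    y[..., 2] = x[..., 0] + x[..., 2] / 2  # bottom-right x
    y[..., 3] = x[..., 1] + x[..., 3] / 2  # bottom-right y
    return y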

1. Final class confidence = class confidence * objectness confidence

# Compute conf
# 80 class confidences * objectness confidence
x[:, 5:] *= x[:, 4:5]  # conf = obj_conf * cls_conf
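
A tiny self-contained example of this broadcast (my own numbers, not from the source): column 4 holds the objectness score and columns 5 onward hold the raw class scores.

import torch

# one box: xywh, objectness 0.9, three class scores 0.8 / 0.1 / 0.05
x = torch.tensor([[10., 10., 20., 20., 0.9, 0.8, 0.1, 0.05]])
x[:, 5:] *= x[:, 4:5]   # conf = obj_conf * cls_conf, broadcast over all classes
print(x[:, 5:])         # tensor([[0.7200, 0.0900, 0.0450]])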

2. NMS filtering:

(1) To avoid running NMS across boxes of different classes, each box is offset by c (class index * max_wh, i.e. * 7680); see the sketch after the code below.

(2) What about different overlapping boxes of the same class during NMS, don't those need handling too? That is where Soft-NMS comes in (see the reference link).

# Batched NMS
# NMS filtering
# c = max_wh (7680) * class index (0-79)
c = x[:, 5:6] * (0 if agnostic else max_wh)  # classes
# offset each box by its class so boxes of different classes are handled separately; note that overlapping boxes of the same class can still suppress each other
boxes, scores = x[:, :4] + c, x[:, 4]  # boxes (offset by class), scores
# returns the indices of the kept boxes (later truncated to max_det = 300)
i = torchvision.ops.nms(boxes, scores, iou_thres)  # NMS
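
A minimal sketch (my own illustration with made-up boxes) of why the class offset makes the NMS call above class-aware: two heavily overlapping boxes are pushed far apart once their class indices differ, so torchvision.ops.nms no longer suppresses one with the other.

import torch
import torchvision

max_wh = 7680
boxes = torch.tensor([[10., 10., 50., 50.],
                      [12., 12., 52., 52.]])   # two almost identical boxes
scores = torch.tensor([0.9, 0.8])

# same class -> the lower-scoring box is suppressed
c = torch.tensor([[0.], [0.]]) * max_wh
print(torchvision.ops.nms(boxes + c, scores, 0.45))   # tensor([0])

# different classes -> the offset separates the boxes, both survive
c = torch.tensor([[0.], [1.]]) * max_wh
print(torchvision.ops.nms(boxes + c, scores, 0.45))   # tensor([0, 1])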

 
