ruijiege

  博客园 :: 首页 :: 博问 :: 闪存 :: 新随笔 :: 联系 :: 订阅 订阅 :: 管理 ::
class GIoULoss(nn.Module):
    """Generalized IoU (GIoU) loss, averaged over all box pairs.

    Boxes are given in corner form ``(x1, y1, x2, y2)`` along dimension 1 of
    the inputs, e.g. shape ``(N, 4)`` or ``(batch, 4, N)``; every remaining
    position is treated as one box pair.

    For one pair the loss is ``1 - GIoU = 1 - (IoU - (C - U) / C)``, where
    ``U`` is the union area and ``C`` is the area of the smallest axis-aligned
    box enclosing both boxes. The per-pair losses are summed and divided by
    the number of pairs.

    Args:
        offset: Added to every width and height. The default ``1`` keeps the
            inclusive-pixel convention (a box spanning x=0..0 is one pixel
            wide); pass ``0`` for continuous coordinates.
        eps: Added to the IoU and enclosing-area denominators to avoid 0/0
            for degenerate boxes. The default ``0`` leaves the arithmetic
            unchanged.
    """

    def __init__(self, offset: float = 1, eps: float = 0.0):
        super().__init__()
        self.offset = offset
        self.eps = eps

    def forward(self, A: torch.Tensor, B: torch.Tensor) -> torch.Tensor:
        """Return the mean GIoU loss between boxes ``A`` and ``B``.

        Args:
            A: Boxes whose dimension 1 holds ``(x1, y1, x2, y2)``.
            B: Boxes with the same shape as ``A``.

        Returns:
            Scalar tensor ``sum(1 - GIoU) / number_of_box_pairs``. For empty
            input, a zero that is still attached to the autograd graph.
        """
        off = self.offset
        ax, ay, ar, ab = A[:, 0], A[:, 1], A[:, 2], A[:, 3]
        bx, by, br, bb = B[:, 0], B[:, 1], B[:, 2], B[:, 3]

        # Intersection rectangle; clamping each side keeps disjoint boxes at
        # zero overlap instead of a (possibly positive) product of negatives.
        cross_width = (torch.min(ar, br) - torch.max(ax, bx) + off).clamp(0)
        cross_height = (torch.min(ab, bb) - torch.max(ay, by) + off).clamp(0)
        cross = cross_width * cross_height

        area_a = (ar - ax + off) * (ab - ay + off)
        area_b = (br - bx + off) * (bb - by + off)
        union = area_a + area_b - cross
        iou = cross / (union + self.eps)

        # Smallest axis-aligned box enclosing both A and B.
        enclose_w = torch.max(ar, br) - torch.min(ax, bx) + off
        enclose_h = torch.max(ab, bb) - torch.min(ay, by) + off
        c = enclose_w * enclose_h

        loss = 1 - (iou - (c - union) / (c + self.eps))

        # One loss value per box pair regardless of input rank. (A fixed
        # A.size(0) * A.size(2) only works for 3-D input and raises
        # IndexError on the common (N, 4) layout.)
        num_bbox = loss.numel()
        if num_bbox == 0:
            # Avoid 0/0 = NaN on empty input; sum() keeps the graph intact.
            return loss.sum()
        return loss.sum() / num_bbox

 

posted on 2022-11-02 17:52  哦哟这个怎么搞  阅读(71)  评论(0)  编辑  收藏  举报