【其他算法】配对关系转组+非极大值抑制

def get_group_from_pair(pair_list):
    """
    功能:根据成对的关系,获得group
    输入:对关系list,如[[1, 3], [2, 3], [4, 5], [3, 6], [6, 8]]
    输出:组关系list,如[[1, 2, 3, 6, 8], [4, 5]]
    方法:dfs, 先获得相邻关系,再串联
    """
    def get_pair_relation(pair_list):
        relation = {}
        for pair in pair_list:
            p1, p2 = pair
            relation.setdefault(p1, []).append(p2)
            relation.setdefault(p2, []).append(p1)
        return relation
    relation = get_pair_relation(pair_list)
    merged_set = set()
    group = []
    for head in relation.keys():
        if head in merged_set: continue  # 已经被合并
        group_item = []
        que = [head]
        while que:
            item = que.pop(0)
            group_item.append(item)
            merged_set.add(item)
            for item_ in relation.get(item, []):
                if item_ not in que and item_ not in merged_set:
                    que.append(item_)
        group.append(group_item)
    return group

def nms(bbox_list, iou_thr=0.2):
    """
    功能:给定可能有压盖的带分数的矩形框list,输出压盖小的list
    输入:[x1, y1, x2, y2, score]list,
    如[[3, 8, 9, 3, 0.9], [5, 7, 8, 4, 0.7], [7, 5, 10, 1, 0.8]];
    输出:剩余的[x1, y1, x2, y2, score]list,
    方法:按分数排序,被压盖的直接打掉
    """
    def cal_iou(bbox1, bbox2):
        x1, y1, x2, y2 = bbox1[:4]
        a1, b1, a2, b2 = bbox2[:4]
        p1 = min(x2, a2) - max(x1, a1)
        p2 = min(y2, b2) - max(y1, b1)
        inter = 0
        if p1 > 0 and p2 > 0:
            inter = p1 * p2
        union = (x2 - x1) * (y2 - y1) + (a2 - a1) * (b2 - b1) - inter
        iou = inter / union
        return iou

    bbox_list = sorted(bbox_list, key=lambda x: x[-1], reverse=True)
    num = len(bbox_list)
    reserve_flag = [True for _ in range(num)]
    res = []
    for i in range(num):
        if reserve_flag[i] is False: continue  # 已经被抑制
        res.append(bbox_list[i])
        for j in range(i + 1, num):
            iou = cal_iou(bbox_list[i], bbox_list[j])
            if iou > iou_thr:
                reserve_flag[j] = False
    return res

if __name__ == '__main__':
    pair_list = [[1, 3], [2, 3], [4, 5], [3, 6], [6, 8]]
    print(get_group_from_pair(pair_list))
    bbox_list = [[3, 3, 9, 8, 0.9], [5, 4, 8, 7, 0.7], [7, 1, 10, 5, 0.8]]
    # bbox_list = [[5, 7, 9, 15, 0.5], [12, 8, 16, 12, 0.7], [14, 5, 20, 13, 0.9]]
    print(nms(bbox_list))
posted @ 2022-12-26 15:08  我若成风者  阅读(17)  评论(0编辑  收藏  举报