def iou(a, b):
    """Return the intersection-over-union of two axis-aligned boxes.

    Each box is a 4-sequence ``(left, top, right, bottom)``. Right and bottom
    are treated as *inclusive* pixel coordinates, hence the ``+ 1`` in every
    width/height computation.

    Returns 0.0 when the boxes do not overlap, or when the union area is not
    positive (degenerate boxes). The latter avoids a division by zero.
    """
    ax, ay, ar, ab = a
    bx, by, br, bb = b

    # The intersection rectangle: latest start and earliest end on each axis.
    cross_x = max(ax, bx)
    cross_y = max(ay, by)
    cross_r = min(ar, br)
    cross_b = min(ab, bb)

    # Clamp to 0 so disjoint boxes contribute no intersection area.
    cross_w = max(0, (cross_r - cross_x) + 1)
    cross_h = max(0, (cross_b - cross_y) + 1)
    cross_area = cross_w * cross_h

    union = (ar - ax + 1) * (ab - ay + 1) + (br - bx + 1) * (bb - by + 1) - cross_area
    if union <= 0:
        # Degenerate (zero/negative-area) boxes: no meaningful overlap.
        return 0.0
    return cross_area / union


def nms(detectiones, threshold, confidence_index=-1):
    """Greedy non-maximum suppression.

    Detections are visited in descending order of confidence. The sort is
    stable, so ties keep their input order. Each kept detection suppresses
    every later detection whose IoU with it is strictly greater than
    ``threshold``.

    Args:
        detectiones: iterable of indexable rows (e.g. a 2-D numpy array or a
            list of tuples). The first four entries of each row are the box
            coordinates, in the format expected by ``iou``.
        threshold: IoU above which a lower-confidence detection is dropped.
        confidence_index: position of the confidence score inside each row.

    Returns:
        A numpy array with the kept rows stacked via ``np.vstack``, ordered by
        descending confidence. For empty input, an empty ``(0, ncols)`` array
        is returned instead of raising. ``ncols`` is taken from a 2-D input's
        shape, otherwise 0.
    """
    ordered = sorted(detectiones, key=lambda det: det[confidence_index], reverse=True)

    if not ordered:
        # np.vstack([]) raises ValueError; return a well-shaped empty array.
        shape = np.shape(detectiones)
        ncols = shape[1] if len(shape) == 2 else 0
        return np.empty((0, ncols), dtype=getattr(detectiones, "dtype", float))

    suppressed = [False] * len(ordered)
    keep = []
    for i, best in enumerate(ordered):
        if suppressed[i]:
            continue
        keep.append(best)
        for j in range(i + 1, len(ordered)):
            # Already-suppressed rows can't be revived, so skip the IoU work.
            if not suppressed[j] and iou(best[:4], ordered[j][:4]) > threshold:
                suppressed[j] = True
    return np.vstack(keep)