Tensorflow版Faster RCNN源码解析(TFFRCNN) (05) nms_wrapper.py

本blog为github上CharlesShang/TFFRCNN版源码解析系列代码笔记

---------------个人学习笔记---------------

----------------本文作者疆--------------

------点击此处链接至博客园原文------

 

1.def nms(dets,thresh,force_cpu=False)

def nms(dets, thresh, force_cpu=False):
    """Dispatch to either CPU or GPU NMS implementations."""
    if dets.shape[0] == 0:
        return []
    # 默认USE_GPU_NMS = True
    if cfg.USE_GPU_NMS and not force_cpu:
        return gpu_nms(dets, thresh, device_id=cfg.GPU_ID)  # gpu_nms.so
    else:
        return cpu_nms(dets, thresh)

选择以GPU或CPU模式执行nms,实际是.so动态链接对象执行,其中,dets是某类box和scor构成的数组,shape为(None,5),被nms_wrapper(...)函数调用

2.def nms_wrapper(scores,boxes,threshold = 0.7,class_sets = None)

# train_model()调用时未传入class_sets参数
def nms_wrapper(scores, boxes, threshold = 0.7, class_sets = None):  # box得分必须大于0.7
    # scores:R * num_class
    # boxes: R * (4 * num_class)
    # return: a list of K-1 dicts, no background, each is {'class': classname, 'dets': None | [[x1,y1,x2,y2,score],...]} 
    num_class = scores.shape[1] if class_sets is None else len(class_sets)
    assert num_class * 4 == boxes.shape[1],\
        'Detection scores and boxes dont match'
    class_sets = ['class_' + str(i) for i in range(0, num_class)] if class_sets is None else class_sets
    # class_sets = [class_0, class_1, class_2...]
    res = []
    # 针对各类,构造该类的dets(含box坐标和score)
    for ind, cls in enumerate(class_sets[1:]):
        ind += 1  # skip background
        cls_boxes = boxes[:, 4*ind: 4*(ind+1)]
        cls_scores = scores[:, ind]
        dets = np.hstack((cls_boxes, cls_scores[:, np.newaxis])).astype(np.float32)
        keep = nms(dets, thresh=0.3)  # nms阈值为0.3
        dets = dets[keep, :]  # 类内nms处理
        dets = dets[np.where(dets[:, 4] > threshold)]  # score必须超过阈值0.7
        r = {}
        if dets.shape[0] > 0:
            r['class'], r['dets'] = cls, dets
        else:
            r['class'], r['dets'] = cls, None
        res.append(r)
        # res为列表,每个元素为某类别标号和dets构成的字典
    return res

最后返回的res为列表,列表中每个元素为字典,每个字典含某类标号如(class_1,class_2...)和dets(该类box的坐标和对应score,针对各类类内先nms后取得分超过阈值的box)

posted @ 2019-07-30 10:41  JiangJ~  阅读(628)  评论(0编辑  收藏  举报