物体检测序列之一:ap, map

准确率(Precision),也叫正确预测率(positive predictive value),在模式识别、信息检索、机器学习等研究应用领域,准确率用来衡量模型预测的结果中相关或者正确的比例。而召回率(recall),也叫敏感度(sensitivity),即模型预测的结果中相关或正确的数量占样本中实际相关或正确的数量的比例。一般在计算机视觉领域物体分类检测任务中,检测出的物体轮廓框如果类别和ground truth的类别相同,并且两者之间IoU大于一个阈值(一般为0.5),那么该预测值是正确的。如果模型预测出数据集中有N个horse这一类的物体,并给出其轮廓框,而实际场景中有T个horse标签的ground truth。在2000个预测值中,有些是错误的,即和ground truth的IoU小于阈值,又或者IoU大于阈值,但是该ground truth已经有和其相关的预测值了,此时该预测值均为假正例(false positive,FP),有些和ground truth的IoU大于阈值,并且为唯一和ground truth相关的预测值,那么该预测值为真正例(true positive, TP)。那么此时准确率的计算为precision=TP/(TP+FP)。而召回率的计算则为recall=TP/T.

准确率和召回率是一对矛盾的度量尺度,一般来说,准确率高时,召回率往往很低,而召回率高时,准确率往往偏低。例如在物体分类检测任务中,一组预测轮廓框有两个指标,一个是分类的score,一组是轮廓框的坐标位置。若希望将找出更多的该类的预测轮廓框,则需要将分类score的阈值降低,输出更多的结果,此时必然会有些预测值为假,造成准确率下降,但是召回率会得以提高。但是如果想准确率高,则需要提高分类score,将分类得分很低的预测值丢弃,那么可能会漏掉很多和ground truth的IoU大于阈值的预测结果,造成召回率降低。

在很多情况中,需要对模型的预测结果进行排序,在物体分类检测中,一般是按照分类score进行排序,排在前面的预测结果一般被认为“最可能”为真正例的预测结果,排在后面的则是模型认为“可能性低”的真正例。按照此顺序逐个把预测结果进行分析,每个预测值均可以得到一个准确率和召回率,以召回率为横轴,准确率为纵轴作图,就得到了准确率-召回率曲线,一般称为precision-recall曲线。Precision-recall曲线图直观地显示了模型在数据集上的准确率、召回率详细分布情况。在进行模型选择时,直观上可以看到,如果模型A的precision-recall曲线将模型B的precision-recall曲线完全包住,则模型A的性能优于B。而若模型A的precision-recall曲线和模型C的precision-recall曲线有交点,则此时无法直观上判定模型的优劣,此时一般用precision-recall曲线下面积的大小来进行比较,这个面积在物体分类检测任务中是另外一项衡量模型优劣的指标ap(average precision)值。而如果数据集有多类,那么每个类计算得到ap后再平均所有类的ap即使mAP. mAP是评估模型在数据集上各个类上的整体表现性能,也是以各个数据集为基础举办各种物体检测竞赛的主要评估指标。

计算AP和mAP的代码,来自PASCAL VOC物体检测竞赛的devkit,是和数据集一起公开的测试评估代码,供参赛者提交代码前进行初步的模型性能评估。可以看到代码先计算了预测值和ground truth的IoU,再而计算tp, fp, precision和recall.而后根据precision和recall计算precision-recall曲线下的面积,即AP,并将每类的precision-recall曲线绘制。下图给出根据precision-recall曲线计算ap的结果。图中给出PASCAL VOC 2007数据检测结果其中两类的precision-recall曲线,从图中可以看出,car这类的曲线弧度比较大,因此ap值比较大。而chair这类的曲线则是直线下降的,因此ap值比较小。

 工程上,我们一般要求recall很高,而precision也很高。那么上图中,对chair这一类的检测就无法达到这个目标,在recall很高时,precision就会很低,也就造成很高的误检率。这种时候,为了降低误检率,即tp,会将ap图从中间截断,根据工程需要,选择一个阈值,如图所示,蓝色红色绿色三条线分别对应不同的recall和precison,如果你想要大的precision,选择蓝色线处,如果想要大recall,选择绿色线处截断ap图。

比如在某次训练的检测器中,对类别car的检测结果如下,因为需要降低误检,此时选择了f1=2*precision*recall/(precision+recall)最高时的阈值,可以看到此时ap并不是很高,说明该检测器还有待提升。

至于map,那就是mean ap的缩写。即各个类的ap做个均值即可,那么就是各个类都高了,均值就高了,map也就大了。

 附录是matlab代码,python版本的现在也一大堆,因为我最开始做检测是一句一句解读matlab代码开始的,所以附上,算作纪念:

function [rec,prec,ap] = VOCevaldet(VOCopts,id,cls,draw)

% load test set
[gtids,t]=textread(sprintf(VOCopts.imgsetpath,VOCopts.testset),'%s %d');

% load ground truth objects
tic;
npos=0;
gt(length(gtids))=struct('BB',[],'diff',[],'det',[]);
for i=1:length(gtids)
    % display progress
    if toc>1
        fprintf('%s: pr: load: %d/%d\n',cls,i,length(gtids));
        drawnow;
        tic;
    end
    
    % read annotation
    rec=PASreadrecord(sprintf(VOCopts.annopath,gtids{i}));
    
    % extract objects of class
    clsinds=strmatch(cls,{rec.objects(:).class},'exact');
    gt(i).BB=cat(1,rec.objects(clsinds).bbox)';
    gt(i).diff=[rec.objects(clsinds).difficult];
    gt(i).det=false(length(clsinds),1);
    npos=npos+sum(~gt(i).diff);
end

% load results
[ids,confidence,b1,b2,b3,b4]=textread(sprintf(VOCopts.detrespath,id,cls),'%s %f %f %f %f %f');
BB=[b1 b2 b3 b4]';

% sort detections by decreasing confidence
[sc,si]=sort(-confidence);
ids=ids(si);
BB=BB(:,si);

% assign detections to ground truth objects
nd=length(confidence);
tp=zeros(nd,1);
fp=zeros(nd,1);
tic;
for d=1:nd   
    % find ground truth image
    i=strmatch(ids{d},gtids,'exact');   
    % assign detection to ground truth object if any
    bb=BB(:,d);          ovmax=-inf;
    for j=1:size(gt(i).BB,2)
        bbgt=gt(i).BB(:,j);
        bi=[max(bb(1),bbgt(1)) ; max(bb(2),bbgt(2)) ; min(bb(3),bbgt(3)) ; min(bb(4),bbgt(4))];
        iw=bi(3)-bi(1)+1;             ih=bi(4)-bi(2)+1;
        if iw>0 & ih>0                
            % compute overlap as area of intersection / area of union
            ua=(bb(3)-bb(1)+1)*(bb(4)-bb(2)+1)+ (bbgt(3)-bbgt(1)+1)*(bbgt(4)-bbgt(2)+1)- iw*ih;            
            ov=iw*ih/ua;
            if ov>ovmax
                ovmax=ov;       jmax=j;
            end
        end
    end
    % assign detection as true positive/don't care/false positive
    if ovmax>=VOCopts.minoverlap
        if ~gt(i).diff(jmax)
            if ~gt(i).det(jmax)
                tp(d)=1;            % true positive
		gt(i).det(jmax)=true;
            else
                fp(d)=1;            % false positive (multiple detection)
            end
        end
    else
        fp(d)=1;                    % false positive
    end
end
% compute precision/recall
fp=cumsum(fp);     tp=cumsum(tp);      rec=tp/npos;   prec=tp./(fp+tp);
% compute average precision
ap=0;
for t=0:0.1:1
    p=max(prec(rec>=t));
    if isempty(p)
        p=0;
    end
    ap=ap+p/11;
end    
if draw
    % plot precision/recall
    plot(rec,prec,'-');
    grid;
    xlabel 'recall'
    ylabel 'precision'
    title(sprintf('class: %s, subset: %s, AP = %.3f',cls,VOCopts.testset,ap));
end

  

posted @ 2023-06-17 15:37  caoeryingzi  阅读(70)  评论(0编辑  收藏  举报