MMDetection Sigmoid Focal Loss解析

Focal Loss[1]是一种用来处理单阶段目标检测器训练过程中出现的正负、难易样本不平衡问题的方法。关于Focal Loss,[2]中已经讲的很详细了,这篇博客主要是记录和补充一些细节。

1.两阶段怎么处理样本数量不平衡的问题

  • 两阶段级联的检测方法: 因为物体可能出现在图片中的任意位置,这些位置构成的集合过于庞大,因此在第一阶段使用RPN将可能性大的位置先筛选出来。这一步会过滤掉很多易于检测的负样本(easy negatives)
  • 有偏差地进行采样: 对第一阶段剩下的样本,再按照例如正负样本1:3的比例进行采样,这种方法相当于隐式地实现了Focal Loss中的\(\alpha\)参数。

2.Sigmoid Focal Loss

论文中没有用一般多分类任务采取的softmax loss,而是使用了多标签分类中的sigmoid loss(即逐个判断属于每个类别的概率,不要求所有概率的和为1,一个检测框可以属于多个类别),原因是sigmoid的形式训练过程中会更稳定。因此RetinaNet分类subnet输出的通道数是 KA 而不是 (K+1)A(K为类别数,A为每个cell铺的anchor数)。

3.Focal Loss 代码分析

MMDetection[3]中实现的Focal Loss如下:

# This method is only for debugging
def py_sigmoid_focal_loss(pred,
                          target,
                          weight=None,
                          gamma=2.0,
                          alpha=0.25,
                          reduction='mean',
                          avg_factor=None):
    
    pred_sigmoid = pred.sigmoid()

    target = target.type_as(pred)
    
    pt = (1 - pred_sigmoid) * target + pred_sigmoid * (1 - target)
    
    focal_weight = (alpha * target + (1 - alpha) *
                    (1 - target)) * pt.pow(gamma)
    
    loss = F.binary_cross_entropy_with_logits(
        pred, target, reduction='none') * focal_weight
    
    loss = weight_reduce_loss(loss, weight, reduction, avg_factor)
    
    return loss

论文中给出的公式是:\(FL(p_t)=-\alpha_t(1-p_t)^\gamma\log(p_t)\),下面分析代码的逻辑:

首先给出两个公式:

\[p_t= \begin{cases} p, &t=1 \\ 1-p, &t=0 \end{cases} \]

\(p_t\)为预测值,表示属于 t 类别的概率,可统一表示为:\(p_t=p*t+(1-p)*(1-t)\)

\[\alpha_t= \begin{cases} \alpha, &t=1 \\ 1-\alpha, &t=0 \end{cases} \]

\(\alpha_t\)为权重参数,表示属于 t 类别的权重,可统一表示为:\(\alpha_t=\alpha*t+(1-\alpha)*(1-t)\)

带入得:\(FL(p_t)=-\alpha_t(1-(p*t+(1-p)*(1-t)))^\gamma\log(p_t)\)

\[=\underbrace{\alpha_t(\overbrace{p(1-t)+t(1-p)}^{pt})^\gamma}_{focal\_weight}*\underbrace{-\log(p_t)}_{cross\ entropy} \]

举一个例子,设

pred=[0.1, 0.3, 0.8, 0.1, 0.1]
target=[0, 0, 1, 0, 0]
α=0.25
γ=2
# sigmoid value of pred
pred_sigmoid=[0.5250, 0.5744, 0.6900, 0.5250, 0.5250]

直接根据论文的公式计算loss可得:\(FL(pred,target)=\underbrace{-0.75*0.5250^2*\log(1-0.5250)*3 - 0.75*0.5744^2*\log(1-0.5744)}_{negatives}\underbrace{-0.25*(1-0.6900)^2*log(0.6900)}_{positives}=0.1364*5\)
与上面的py_sigmoid_focal_loss函数(计算的是平均值)计算结果相同。

4.关于\(\alpha_t\)

因为Focal Loss的本意是将loss集中在正样本上,所以我一直以为α=0.25是负样本的权重,但是调试代码时发现0.25其实是乘在正样本上了。这是一个比较矛盾的地方,因为检测任务中负样本比正样本要多很多,而且大部分都是论文中提到过的easy negatives。自然的想法当然是降低这部分loss的权重,让训练朝着更有意义的方向进行,所以我们给正样本的α设大一点,负样本是1-α,因此会比较小。直到看到[2]评论区的讨论,个人觉得还是比较有说服力的:

重新去查了下focal loss论文,在gamma=0时,alpha=0.75效果更好,但当gamma=2时,alpha=0.25效果更好,个人的解释为负样本(IOU<=0.5)虽然远比正样本(IOU>0.5)要多,但大部分为IOU很小(如<0.1)以至于在gamma作用后某种程度上贡献较大损失的负样本甚至比正样本还要少,所以alpha=0.25要反过来重新平衡负正样本。

大意就是负样本大部分都是容易检测的,用于平衡难易样本地γ取2时,负样本的loss会过度地衰减,因此需要α进行反向地平衡。我没有用代码验证过,不过这些都是超参,研究的意义也不大,定性地分析应该足够。

5.TODO

mmdetection的py_sigmoid_focal_loss实现其实有一点问题,不能直接替换sigmoid_focal_loss,不过最近已经修改过了,这部分以后有机会再细说。

参考

  1. https://arxiv.org/pdf/1708.02002.pdf
  2. https://zhuanlan.zhihu.com/p/80594704
  3. https://github.com/open-mmlab/mmdetection
  4. https://mingming97.github.io/2019/03/29/mmdetection retinanet
posted @ 2021-01-27 20:45  backtosouth  阅读(4968)  评论(0编辑  收藏  举报