长尾问题的在cv深度学习中的解决方案

本文方法,在LVIS challenge 2019比赛中获得第一名,整体ap比第二名高2.2个点,rare类别比第二名高7.1个点。在没有额外策略,仅使用本文方法的前提下,在rare类别上比baseline方法高4.1个点(加上各种策略后,再提升3.8个点)。

一、长尾问题的一般解决方案

 


在实际的视觉相关任务中,数据都存在如上图所示的长尾分布,少量类别占据了绝大多少样本,如图中Head部分,大量的类别仅有少量的样本,如图中Tail部分。解决长尾问题的方案一般分为4种:

  1. Re-sampling:主要是在训练集上实现样本平衡,如对tail中的类别样本进行过采样,或者对head类别样本进行欠采样
  2. Re-weighting:主要在训练loss中,给不同的类别的loss设置不同的权重,对tail类别loss设置更大的权重
  3. Learning strategy:有专门为解决少样本问题涉及的学习方法可以借鉴,如:meta-learning、metric learning、transfer learing。另外,还可以调整训练策略,将训练过程分为两步:第一步不区分head样本和tail样本,对模型正常训练;第二步,设置小的学习率,对第一步的模型使用各种样本平衡的策略进行finetune。
  4. 综合使用以上策略

二、本文方案

本文介绍的方案属于第3种,来自论文“Equalization Loss for Long-Tailed Object Recognition”,出自商汤。下载链接:

Paper:arxiv.org/pdf/2003.0517

Code:github.com/tztztztztz/e

1,为什么说这个方案比较简洁呢?

  1. 出发点比较简单:减少梯度反向传播时对tail样本的惩罚
  2. 仅有1个超参需要人工调节。
  3. 可以嵌入到任何模型训练中

2,以检测任务为例,修改了检测任务中分类的loss:在交叉熵loss的基础上,增加了一个权重,如下式

其中wj计算方式如下:

 

 


其中E(r)为二值,当r为前景类别时,为1,为背景类别时,为0。Tλ(fj)也为二值,当fj小于λ时为1,反之为0,λ为阈值,需要人工设定,fj为第j类样本的频率, Nj为第j类样本的图片数,N为训练集样本总数,yj为groundtruth。

3,结果

1)该方法在LVIS Challenge 2019比赛中,获得了第一名。

 


2)在没有其他balance策略的前提下,在LVIS v0.5的验证集上,应用在不同检测方法上,对检测结果有稳定3个点左右的涨幅。

 


三、code

代码部分主要用来实现wj的计算即可,如下,增加了简单的注释:

    def exclude_func(self):
        # E(r)的实现
        # instance-level weight
        bg_ind = self.n_c
        # 对背景类别置为0,非背景类别置为1
        weight = (self.gt_classes != bg_ind).float()
        weight = weight.view(self.n_i, 1).expand(self.n_i, self.n_c)
        return weight

    def threshold_func(self):
        # T(x)实现
        # class-level weight
        weight = self.pred_class_logits.new_zeros(self.n_c)
        # 对小于λ的置为1,其他为0
        weight[self.freq_info < self.lambda_] = 1
        weight = weight.view(1, self.n_c).expand(self.n_i, self.n_c)
        return weight

    def eql_loss(self):
        # eql loss的实现
        self.n_i, self.n_c = self.pred_class_logits.size()

        def expand_label(pred, gt_classes):
            target = pred.new_zeros(self.n_i, self.n_c + 1)
            target[torch.arange(self.n_i), gt_classes] = 1
            return target[:, :self.n_c]

        target = expand_label(self.pred_class_logits, self.gt_classes)
        # wj的实现
        eql_w = 1 - self.exclude_func() * self.threshold_func() * (1 - target)

        cls_loss = F.binary_cross_entropy_with_logits(self.pred_class_logits, target,
                                                      reduction='none')

        return torch.sum(cls_loss * eql_w) / self.n_i

from:https://zhuanlan.zhihu.com/p/127791648

posted @ 2022-08-02 09:12  海_纳百川  阅读(556)  评论(0编辑  收藏  举报
本站总访问量