CVPR 2022 | 解耦知识蒸馏

CVPR 2022 | 解耦知识蒸馏

  近年来,SOTA蒸馏方法多基于网络中间层的深层特征,而基于logit的KD(知识蒸馏)则被忽略。

  旷视科技、早稻田大学、清华大学的研究者研究了传统KD的局限之处后,给logit蒸馏的研究提出了新的方法。

  作者证明传统KD的损失函数存在耦合,该耦合限制了logit蒸馏的性能。

  作者提出将耦合的损失函数解耦成TCKD(target class KD)和NCKD(non-target class KD)的加权和。

  这样,1)释放了NCKD部分对模型性能的影响; 2) 解耦了TCKD和NCKD,可以分别配置权重,平衡二者对模型性能的影响。

  通过一系列的实验,作者也证明了本文的DKD(Decoupled Knowledge Distillatio) 方法可以在更高的训练效率下达到相当于甚至优于SOTA蒸馏方法的效果。

论文地址:
https://arxiv.org/abs/2203.08679
代码地址:
https://github.com/megvii-research/mdistiller

image

一、本文方法

  按理说,logits在语义水平上要比深层特征更高级,那么logit蒸馏应该是可以达到与特征蒸馏相当的水平,但是现实却是:logit蒸馏虽然在算力和存储消耗上有优势,但是性能却远远比不上特征蒸馏。

  是不是有什么“暗黑力量”封禁了logit的“洪荒之力”?

  那还得从KD的机制出发,反思一下,有没有什么不合理的地方。

  分类预测可以被看成两个层次:1)目标类+所有非目标类的二元预测; 2)每个非目标类的多元预测。

  传统的KD训练机制如下图:

image

  作者通过详细的推导(原文中),将损失函数整理为:

image

  其中T,S分别表示教师和学生网络;b表示目标类的二元概率; $p_t$ 表示目标类概率;$\hat{p}$ 表示非目标类的多元概率。

   从上式可以看出,KD损失函数被重构成两部分:

   $KL(bT||bS)$ 代表教师和学生网络的目标类二元概率的相似性,命名为TCKD;$KL(\hat{p}T||\hat{p}S)$ 代表教师和学生网络非目标类多元概率的相似性,命名为NCKD。所以KD损失函数可以表示为:

image

  可以看出,NCKD的权重被 $p_t^T$ 耦合

  这种耦合是否正是抑制NCKD“洪荒之力”的“暗黑力量” ?

  作者通过实验研究了TCKD和NCKD两个部分的作用,以及耦合的影响:

image

  上表可以看出一个有趣的现象:只用NCKD一个部分时的性能甚至比baseline还要好,足以说明NCKD部分对于整个网络的重要性,“洪荒之力”确实被耽误了!

image

  上表也证明了NCKD在预测情况较好的情况下作用更大,应该有更大的权重才对,但是从上边的损失函数公式可以看出,现实是NCKD的权重却随着 $p_t^T$的增大而减小

  但是到此处还看不出TCKD的明显优势,作者从上述公式推导过程推测,TCKD部分表征的是识别训练样本的难度,并通过实验证明了这一推测。

  作者通过对数据集加入AutoAugment、噪音以及用更具有挑战性的数据集等方法来提升训练难度,结果如下:

image

image

  这一部分实验证明了TCKD在训练难度提升情况下的有效作用。

  研究到此,将损失公式重构解耦势在必行。

  解耦的好处有两点:

  1)将NCKD从$(1-p_t^T)$ 中解耦,释放其“洪荒之力”;

  2)将TCKD和NCKD两个部分解耦,分别配置权重,平衡二者的作用。

  作者将新的logit蒸馏命名为DKD(Decoupled Knowledge Distillatio)

  损失函数如下:

image

  这里的$\alpha$和$\beta$ 是可以自行调整的超参数,用于平衡TCKD和NCKD两个部分的重要性。

二、实验分析

1 解耦对于性能的提升

image

2 图像分类-CIFAR100

image

image

3 图像分类-ImageNet

image

4 目标检测-MS-COCO

  logits无法提供目标定位的知识,所以单独将DKD用于目标检测效果不佳,但是DKD可以辅助用于目标定位的模型,提升其性能。

image

5 DKD vs SOTA --训练效率

image

6 更大的教师网络

  更大的教师网络与小的学生网络之间的容量差增大,反而会使蒸馏效果变差。本文DKD方法通过解耦,可以缓解这个问题。

image

7 特征转化性能

image

8 可视化

image

三、总结

  本文给了logit蒸馏一种新的解释,将传统的KD损失重构成两部分--TCKD和NCKD。两部分的重要性都得到证明,耦合的KD方式限制了知识迁移的有效性的灵活性也被证明。
  本文提出的解耦蒸馏(DKD)被用于不同数据集,用于完成图像分类和目标检测任务,效果超越了传统logitKD,达到甚至超越特征蒸馏的SOTA方法。

作者:万国码
https://www.cnblogs.com/ljdsbfj/p/16294900.html

posted @ 2022-06-08 22:33  万国码aaa  阅读(611)  评论(0编辑  收藏  举报