DETR系列之DN-DETR

DN-DETR

CVPR 2022 的一篇文章

一、Introduction

之前许多工作对 detr 的encoder或是decoder结构进行了改进,以期改善收敛慢的现象。本文作者从另一个角度(训练方法的角度)分析和解决了detr收敛慢的问题。

第一次提出了全新的去噪训练(DeNoising training)解决了DETR decoder在训练过程中二分图匹配 (bipartite graph matching)不稳定的问题,可以让模型收敛速度翻倍,并对检测结果带来显著提升(+1.9AP)。该方法简易实用,可以广泛运用到各种DETR模型当中,以微小的训练代价带来显著提升。

二、Model

(一)二分图匹配的不稳定性导致训练速度慢

我的理解是:匈牙利匹配算法会根据cost metric将两个匹配程度最高的框,作为一对匹配,在此基础上计算损失并更新模型,而在训练过程中模型的更新会使得产生的预测框发生变化,而这种变化会导致cost metric的变化,进而很容易导致匹配结果与之前的匹配结果不同(例如之前是预测框a匹配gtbox 6,模型训练更新会向着使预测框a与gtbox6匹配程度更高的方向去调整,但这种调整不仅会影响a与gtbox6的匹配程度,还会无意中影响到a与其他gtbox的匹配程度,所以有可能会产生更新后预测框a与gtbox 10的匹配程度高于与gtbox6的匹配程度的情况,这种情况下a又变为与gtbox10匹配了),即二分图匹配的不稳定性。而这种不稳定性会使得loss值发生波动,使得优化目标具有不连续性,阻碍模型的收敛

针对这种不稳定性,作者设计了评判标准进行量化实验验证:

image-20220405145706655

(二)DN-DETR

为了解决二分图匹配的不稳定问题,作者提出了一种新的训练方式,就是在原有基础上增加一个训练任务,来提高训练过程的稳定性。该工作在DAB-DETR基础上进行展开。

在DAB-DETR中,cross-attention的输入query由两部分:learnable anchors(anchor box参数,包括x y w h),decoder embeddings(学习目标的内容信息)。

在DN-DETR中,为了更好的发挥新增加的训练任务denoising task的作用,将decoder embedding替换为了带有目标标签信息的class label embedding,并且附加了一个指示器indicator,用来区分是denoising task还是matching task。

执行matching task时,除了输入的class label embedding是unknown class之外,其他的部分都与之前的DAB-DETR相同;

执行denoising task时,输入的learnable anchors是将gtbox信息进行中心点偏移或者边框缩放得到的,class label embedding是将真是标签按照一定比例进行随机翻转得到的。在denoising task中,由于事先知道输入的信息对应于哪一个gtbox,所以在计算损失时不需要进行二分图匹配,就不存在匹配不稳定的问题。

在denoising task的干预下,训练的不稳定性降低。

原文描述如下:

To address this problem, we propose a novel training method by introducing a query denoising task to help stabilize bipartite graph matching in the training process.

Our solution is to feed noised ground truth bounding boxes as noised queries together with learnable anchor queries into Transformer decoders. Both kinds of queries have the same input format of (x, y, w, h) and can be fed into Transformer decoders simultaneously.

For noised queries, we perform a denoising task to reconstruct their corresponding ground truth boxes. For other learnable anchor queries, we use the same training loss including bipartite matching as in the vanilla DETR.

另外,由于denoising task和matching task时同时进行的,所以在内部计算时可能会出现一些信息的交互,使得matching task部分获知了denoising task部分输入的信息(由于是从gtbox加噪来的,所以带有gtbox信息),也就是提前知道了“答案”,这会损害matching部分的学习(最终预测时只保留matching部分,所以它学习到的能力才是最关键的)。因此作者设计了一个attention mask来阻止这种信息的交互。

下面是总体图:

三、Experiments

posted @ 2022-04-08 14:56  彼岸的客人  阅读(1647)  评论(1编辑  收藏  举报