[CVPR2022]DASO Distribution-Aware Semantics-Oriented Pseudo-label for Imbalanced Semi-Supervised Learning
问题的背景设置:半监督学习下,labeled data和unlabeled data的分布不同,且存在类别不平衡。文章提出了一种新的伪标签生成方法:DistributionAware Semantics-Oriented (DASO) Pseudo-label。首先生成语义伪标签和线性为标签,然后将它们混合实现互补。
另外作者的方法不需要估计无标签数据的先验、也不需要假设labeled/unlabeled data 分布一致的假设。
- 线性分类器(linear classifier):通过fc layer实现。
- 语义分类器(semantic classifier):通过衡量表征间的相似性(如prototype)实现。
作者方法的提出基于观察:基于语义的分类器分配的语义伪标签倾向于少数类,这与线性分类器得到的线性伪标签表现相反。
对于两种伪标签的混合权重也需要调整。具体来说,需要根据当前伪标签分布,调整语义伪标签权重,使得减少对线性伪标签的偏置。
DASO伪标签框架
线性伪标签:\(\hat{p}\);语义伪标签:\(\hat{q}\);两者最终的结合标签:\(\hat{p}'\)。
线性伪标签直接通过linear + softmax 获取;而语义伪标签需要先计算每个类的prototype。
对于每个类的prototype,定义为\(\mathbf{C}=\{c_k\}_{k=1}^K\),并为每个类准备一个先进先出的记忆队列\(\mathbf{Q}=\{Q_k\}_{k=1}^K\),每个类的队列长度为\(|Q_{k}|\)。每个类的prototype为对列内的特征均值,并在每个训练的step,push新的特征,当队列满时,pop最早的特征。
为了避免数据不平衡对prototype的影响,提出两种策略:1、对于每个类的队列上限保持一致;2、使用exponential moving average (EMA)更新提取用于prototype特征的model:$\theta'\leftarrow\rho\theta'+(1-\rho)\theta $降低模型的更新速度(此处的模型是额外引入的模型,仅用于提取特征,不同于框架其他地方使用的模型)。
语义伪标签计算公式为
其中sim表示余弦相似度。最后的伪标签为:
最后的综合伪标签为:
其中\(v = \{v_{k}\}_{k=1}^{K}\)为分布感知权重,防止\(\hat{p}\)过于偏向头部类,\(v_{k} = \frac{1}{\max_{k} \hat{m}_{k}^{1/T_{\mathrm{dist}}}} \left(\hat{m}_{k}^{1/T_{\mathrm{dist}}}\right)\),\(\hat{m}\)表示归一化后的伪标签分布。当线性伪标签\(\hat{p}\)预测为头部类,如果偏置较大,就会有更多的语义伪标签\(\hat{q}\)被混合进来。
无标签数据的预测:\(p=f(\mathcal{A}_s(u))\),这里\(\mathcal{A}_s\)表示强图像增强,并将预测结果与最后的伪标签计算损失\(\mathcal{L}_{u}=\Phi_u(\hat{p}, p)=\mathbb{I}\left(\max_kp_k\geq\tau\right)\mathcal{H}\left(\hat{p},p\right)\),\(\mathcal{H}\)表示交叉熵。
为了保证表征更平衡,作者模仿了FixMatch的一致性正则:
这里\(q^{(s)}\)是通过强图像增强后提取表征使用语义相似度分类器+softmax 后的得到的结果。最后总的损失为:
这里\(\mathcal{L}_u\)使用的标签为混合伪标签,而\(\mathcal{L}_u, \mathcal{L}_{cls}\)损失函数的定义由相应的半监督学习框架决定。作者的伪标签生成框架以及\(\mathcal{L}_{align}\)适合于其他半监督学习框架。
参考文献
- Oh, Youngtaek, Dong-Jin Kim, and In So Kweon. "Daso: Distribution-aware semantics-oriented pseudo-label for imbalanced semi-supervised learning." Proceedings of the IEEE/CVF conference on computer vision and pattern recognition. 2022.