[CVPR2022]DASO Distribution-Aware Semantics-Oriented Pseudo-label for Imbalanced Semi-Supervised Learning

问题的背景设置:半监督学习下,labeled data和unlabeled data的分布不同,且存在类别不平衡。文章提出了一种新的伪标签生成方法:DistributionAware Semantics-Oriented (DASO) Pseudo-label。首先生成语义伪标签和线性为标签,然后将它们混合实现互补。

另外作者的方法不需要估计无标签数据的先验、也不需要假设labeled/unlabeled data 分布一致的假设。

  • 线性分类器(linear classifier):通过fc layer实现。
  • 语义分类器(semantic classifier):通过衡量表征间的相似性(如prototype)实现。

作者方法的提出基于观察:基于语义的分类器分配的语义伪标签倾向于少数类,这与线性分类器得到的线性伪标签表现相反。

对于两种伪标签的混合权重也需要调整。具体来说,需要根据当前伪标签分布,调整语义伪标签权重,使得减少对线性伪标签的偏置。

DASO伪标签框架

线性伪标签:\(\hat{p}\);语义伪标签:\(\hat{q}\);两者最终的结合标签:\(\hat{p}'\)

线性伪标签直接通过linear + softmax 获取;而语义伪标签需要先计算每个类的prototype。

对于每个类的prototype,定义为\(\mathbf{C}=\{c_k\}_{k=1}^K\),并为每个类准备一个先进先出的记忆队列\(\mathbf{Q}=\{Q_k\}_{k=1}^K\),每个类的队列长度为\(|Q_{k}|\)。每个类的prototype为对列内的特征均值,并在每个训练的step,push新的特征,当队列满时,pop最早的特征。

为了避免数据不平衡对prototype的影响,提出两种策略:1、对于每个类的队列上限保持一致;2、使用exponential moving average (EMA)更新提取用于prototype特征的model:$\theta'\leftarrow\rho\theta'+(1-\rho)\theta $降低模型的更新速度(此处的模型是额外引入的模型,仅用于提取特征,不同于框架其他地方使用的模型)。

语义伪标签计算公式为

\[q={\rm softmax}({\rm sim}(z,\mathbf{C}) / T_{\mathrm{proto}}) \]

其中sim表示余弦相似度。最后的伪标签为:

\[\hat{p}'=(1-v_{k'}) \hat{p}+v_{k'}\hat{q} \]

最后的综合伪标签为:

\[\hat{p}'=(1-v_{k'}) \hat{p}+v_{k'}\hat{q} \]

其中\(v = \{v_{k}\}_{k=1}^{K}\)为分布感知权重,防止\(\hat{p}\)过于偏向头部类,\(v_{k} = \frac{1}{\max_{k} \hat{m}_{k}^{1/T_{\mathrm{dist}}}} \left(\hat{m}_{k}^{1/T_{\mathrm{dist}}}\right)\)\(\hat{m}\)表示归一化后的伪标签分布。当线性伪标签\(\hat{p}\)预测为头部类,如果偏置较大,就会有更多的语义伪标签\(\hat{q}\)被混合进来。

无标签数据的预测:\(p=f(\mathcal{A}_s(u))\),这里\(\mathcal{A}_s\)表示强图像增强,并将预测结果与最后的伪标签计算损失\(\mathcal{L}_{u}=\Phi_u(\hat{p}, p)=\mathbb{I}\left(\max_kp_k\geq\tau\right)\mathcal{H}\left(\hat{p},p\right)\)\(\mathcal{H}\)表示交叉熵。

为了保证表征更平衡,作者模仿了FixMatch的一致性正则:

\[\mathcal{L}_{\mathrm{align}}=\mathcal{H}\left(\hat{q}, q^{(s)}\right) \]

这里\(q^{(s)}\)是通过强图像增强后提取表征使用语义相似度分类器+softmax 后的得到的结果。最后总的损失为:

\[\mathcal{L}_{\mathrm{DASO}}=\mathcal{L}_{\mathrm{cls}}+\lambda_u\mathcal{L}_u+\lambda_{\mathrm{align}}\mathcal{L}_{\mathrm{align}} \]

这里\(\mathcal{L}_u\)使用的标签为混合伪标签,而\(\mathcal{L}_u, \mathcal{L}_{cls}\)损失函数的定义由相应的半监督学习框架决定。作者的伪标签生成框架以及\(\mathcal{L}_{align}\)适合于其他半监督学习框架。

参考文献

  1. Oh, Youngtaek, Dong-Jin Kim, and In So Kweon. "Daso: Distribution-aware semantics-oriented pseudo-label for imbalanced semi-supervised learning." Proceedings of the IEEE/CVF conference on computer vision and pattern recognition. 2022.
posted @ 2024-08-09 20:32  October-  阅读(48)  评论(0编辑  收藏  举报