Unsupervised Semantic Segmentation by Distilling Feature Correspondences
概
本文介绍了一种无监督的语义分割方法, 只需在 frozen backbone 上训练一个 head 就可以起到语义分割的作用. 感觉比较好利用.
流程
-
backbone \(\mathcal{N}\) 将两个图片 \(x, x'\) 映射为特征 \(f \in \mathbb{R}^{CHW}, g \in \mathbb{R}^{CIJ}\);
-
计算二者间的 feature correspondence:
\[\tag{1} F_{hwij} := \sum_c \frac{f_{chw} \cdot g_{cij}}{\|f_{hw}\| \cdot \| g_{ij}\|}, \]即 spatial element-wise 的 cosine similarity. 倘若 \(f =g\), 则可以衡量不同区域间的 correspondence;
-
segmentation head \(\mathcal{S}\) 将 \(f, g\) 分别映射为
\[\tag{2} s = \mathcal{S}(f) \in \mathbb{R}^{KHW}, \: t = \mathcal{S}(g) \in \mathbb{R}^{KIJ}, \]类似 (1) 计算二者间的 feature correspondence \(S_{hwij}\);
-
得到如下的损失函数:
\[\tag{3} \mathcal{L}_{simple-corr}(x, x', b) := - \sum_{hwij} (F_{hwij} - b) S_{hwij}, \]其中 \(b\) 为一个超参数. 注意到, 关于 \(\mathcal{S}\) 最小化上式, 有
\[S_{hwij} \uparrow \: \text{ if } F_{hwij} > b, \\ S_{hwij} \downarrow \: \text{ if } F_{hwij} < b. \\ \]故合适的 \(b\) 会促使 \(\mathcal{S}_{hwij}\) 准确度量一致性;
-
但是 (3) 在训练的时候并不稳定, 作者先将 \(F_{hwij}\) 进行中心化, 即
\[F_{hwij}^{SC} := F_{hwij} - \frac{1}{IJ} \sum_{i'j'} F_{hwi'j'}. \]然后用如下的损失进行替代:
\[\tag{4} \mathcal{L}_{corr}(x, x', b) := - \frac{1}{HWIJ} \sum_{hwij} (F_{hwij}^{SC} - b) \max(S_{hwij}, 0); \] -
最后的损失为
\[\tag{5} \mathcal{L} =\lambda_{self}\mathcal{L}_{corr}(x, x, b_{self}) +\lambda_{knn}\mathcal{L}_{corr}(x, x^{knn}, b_{self}) +\lambda_{rand}\mathcal{L}_{corr}(x, x^{rand}, b_{self}). \]其中 \(x\) 和其本身 \(x\) 或类似的 (positive) 样本 \(x^{knn}\) (通过 KNN 选取的) 的损失主要是为了学习正向的 \(S_{hwij} \uparrow\) 的信息, 而 \(x\) 和随机的样本 \(x^{rand}\) 之间的损失则是为了更多的提供 \(S_{hwij} \downarrow\) 等负向的排斥的信息的学习;
-
为了给学习得到的特征图 \(S(f)\) 进行语义分割, 可以采用如下两种方式:
- Linear Probe: 以线性网络和部分监督信息, 凭借交叉熵损失即可训练;
- Clustering.
-
通过 connected Conditional Random Field (CRF) 对语义分割进行微调.
注: Clustering 我看代码是按照如下方式训练的:
-
随机初始化一些类别中心 \(\mu_1, \mu_2, \cdots, \mu_K\);
-
对于每个 batch \(\{x_1, \cdots, x_n\}\) 得到 \(\{s_1, \cdots, s_n\}\), 每个 \(s_i \in \mathbb{R}^{CH_iW_i}\);
-
对每个 \(x_i, \mu_k\) 进行标准化 \(x_i = \frac{x_i}{\|x_i\|}, \mu_k = \frac{\mu_k}{\|\mu_k\|}\);
-
对于每个 channel 计算内积:
\[z_{ikc} = s_{hw}^T \mu_k, \: i \in [n], k \in [K], c \in [C]; \] -
对于每组 \(z_{ik} \in \mathbb{R}^C\) 找到最大的:
\[m_{ik} =\arg\max_c z_{ikc}; \] -
然后通过如下损失优化:
\[\min_{\mu_k} \: - \frac{1}{NK} \sum_{ik} z_{ikm_{ik}}. \]个人感觉, 这就是一种特殊的梯度版的 K-means 算法来更新 \(\mu\).
代码
[official]