对比学习simclr(附pytorch代码)
SimCLR
Chen, Ting, et al. "A simple framework for contrastive learning of visual representations." International conference on machine learning. PMLR, 2020.
解读下论文中的伪代码:
输出:
- 一个batch size 为 \(N\)的batch,\(k\in\{1,\dots,N\}\);
- 度参数\(\tau\);
- 特征提取器\(f\),投影层\(g\),数据增强集合\(\mathcal{T}\);
过程:
- 随机选择两种数据增强\(t,t'\in\mathcal{T}\)
- 第一次样本操作:图像增强:\(\tilde{x}_{2k-1}=t(x_k)\),表征提取:\(z_{2k-1}=f(\tilde{x}_{2k-1})\),投影:\(h_{2k-1}=g(z_{2k-1})\);
- 第二次样本操作:图像增强:\(\tilde{x}_{2k}=t'(x_k)\),表征提取:\(z_{2k}=f(\tilde{x}_{2k})\),投影:\(h_{2k}=g(z_{2k})\);
最后得到 \(2N\) 个结果,对于batch内,任意一个奇数索引(\(2k-1\))的样本是anchor,右侧相邻的偶数索引(\(2k\))样本是positive样本,其余\((2N-2)\)的样本都是negative样本;
- \(2N\)个样本两两计算相似度:
- 定义batch内,每个样本的loss计算:
分子部分:anchor与positive样本的相似度;分母部分:去除了anchor与anchor的样本对相似度,包含了anchor与positive的样本对,以及所有anchor与negative的样本对;
- batch的总损失定义:\[\mathcal{L}=\frac{1}{2N}\sum_{k=1}^{N}\left[\ell(2k-1,2k)+\ell(2k,2k-1)\right] \]
- 最小化损失函数\(\mathcal{L}\),更新\(f\)和\(g\);
下面给出SimCLR的PyTorch实现代码,相较于伪代码,实际实现中发生了变化。在每个batch中,anchor样本与对应的positive样本不是相邻的,而在两个batch堆叠的,例如 \(z_i\ (i<N)\)对应的正样本是\(z_{i+N}\)
import torch
from torch import nn
import torch.nn.functional as F
def cal_cosine_similarity(a, b):
feat = torch.cat((a, b), dim=0)
return F.cosine_similarity(feat.unsqueeze(1), feat.unsqueeze(0), dim=-1)
def simclr(z_i, z_j, temperature=0.5):
batch_size = z_i.shape[0]
sim_matrix = cal_cosine_similarity(z_i, z_j)
sim_ij = torch.diag(sim_matrix, batch_size) # 取anchor样本与positive样本的相似度
sim_ji = torch.diag(sim_matrix, -batch_size) # 取positive样本与anchor样本的相似度
mask = (~torch.eye(batch_size * 2, batch_size * 2, dtype=torch.bool)).float().to(z_i.device)
positives = torch.cat([sim_ij, sim_ji], dim=0)
nominator = torch.exp(positives / temperature) # 分子
denominator = mask * torch.exp(sim_matrix / temperature) # 分母
all_loss = -torch.log(nominator / torch.sum(denominator, dim=1))
return torch.mean(all_loss)
if __name__ == '__main__':
batch_size = 4
feat_dim = 16
torch.manual_seed(0)
z_i = torch.rand((batch_size, feat_dim))
z_j = torch.rand((batch_size, feat_dim))
z_i = F.normalize(z_i, dim=-1)
z_j = F.normalize(z_j, dim=-1)
loss = simclr(z_i, z_j)
print(loss)
或者使用NT-Xent loss(the normalized temperature-scaled cross entropy loss)的形式,使用交叉熵来计算损失,但是这种方式需要自己设计labels
def nt_xent_loss(z_i, z_j, temperature=0.5):
batch_size = z_i.shape[0]
sim_matrix = cal_cosine_similarity(z_i, z_j) / temperature
# 构建标签
labels = torch.arange(batch_size)
labels = torch.cat([labels + batch_size, labels], dim=0).to(z_i.device)
# 构建mask,避免将样本与自身作为负样本
mask = torch.eye(2 * batch_size, dtype=torch.bool).to(z_i.device)
sim_matrix.masked_fill_(mask, -1e9) # 在softmax计算时,exp(-1e9) ~= 0
return F.cross_entropy(sim_matrix, labels)
MoCo
He, Kaiming, et al. "Momentum contrast for unsupervised visual representation learning." Proceedings of the IEEE/CVF conference on computer vision and pattern recognition. 2020.
观察MoCo的损失函数:
其实与SimCLR很相似,只是在分母的\(\sum\)部分,SimCLR的item数量是2 * batch_size,而MoCo的item数量是memory bank的大小,
MoCo的损失函数同样可以使用交叉熵的形式来计算。对于labels的设置,batch内,每个anchor(\(q\))样本与正样本的点积都放在logits的第一位,所以batch内样本的labels都是0。
可以发现在很多实际的自监督学习和对比学习的任务(例如SimCLR、MoCo等),InfoNCE通常是更加常见的选择,因为它的softmax形式更适合于多负样本的设置,并能够有效地进行全局的归一化。
InfoNCE
具体来说,InfoNCE损失的公式通常写成:
在这个公式中,\(K\) 是所有可能的样本的数量,包括正样本和负样本。这里的分母是对所有样本(正样本和负样本)计算相似度的总和。为了让公式更清晰,我们可以具体区分一下:
- 查询样本 \(\mathbf{v}_i\) 和 正样本 \(\mathbf{v}_j\)的相似度被放在分子中。
- 所有负样本(比如 $ \mathbf{v}_k $)以及正样本 $ \mathbf{v}_j $ 都被包括在分母的总和中。
InfoNCE loss(Information Noise Contrastive Estimation)和NCE loss(Noise Contrastive Estimation)有密切的关系。可以说,InfoNCE是NCE loss的一个特例或变体,都是基于对比学习(contrastive learning)思想的损失函数。
NCE损失最早是为了解决概率模型估计问题而提出的,尤其是在大规模数据集上进行训练时,直接计算概率分布通常是非常昂贵的。NCE的目标是通过区分“信号”(真实数据)和“噪声”(负样本)来简化概率估计。
NCE损失的公式通常表示为:
其中:
- \(\mathbf{v}_i\) 是查询样本的特征向量;
- \(\mathbf{v}_j\) 是正样本的特征向量;
- \(\mathbf{v}_k\) 是负样本的特征向量;
- \(\sigma(x)\) 是sigmoid函数,\(\sigma(x) = \frac{1}{1 + \exp(-x)}\);
- \(K\) 是负样本的数量。
InfoNCE和NCE的关系
InfoNCE是NCE的一种变体,主要有以下几个区别和联系:
-
概率估计方式:
- NCE:采用的是基于sigmoid的概率估计方式,尝试估计正样本与负样本的区分概率。
- InfoNCE:采用的是softmax的归一化方式,它将所有负样本与正样本的相似度进行归一化处理,并最大化正样本的相似度。
NCE和InfoNCE都使用了正负样本的对比,但InfoNCE通常会显得更为直接,因为它直接对负样本的分布进行了归一化。
-
温度参数:
- 在InfoNCE中,引入了温度参数\(\tau\),来控制正负样本之间相似度的平滑程度。更大的温度参意味着更平滑的分布。
- NCE中没有显式的温度参数,计算的相似度直接依赖于内积的结果。
-
学习目标:
- 在NCE中,目标是学习区分正样本和负样本的相对概率(通常是通过sigmoid函数来实现)。
- 在InfoNCE中,目标是通过最大化正样本与查询样本之间的相似度,并最小化查询样本与负样本之间的相似度,通常是通过softmax计算相似度。