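Quantitative Evaluation of Contrastive Learning
Understanding Contrastive Representation Learning through Alignment and Uniformity on the Hypersphere: paper reading notes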
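The two key properties of the contrastive loss:
- Alignment of positive-pair features (the two samples of a positive pair should map to nearby features).
- Uniformity of the feature distribution on the hypersphere (a uniform distribution preserves as much information as possible).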
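Written as formulas, the two properties correspond to the two losses below. This is a sketch of the definitions as given in the paper, where f is the encoder with L2-normalized outputs, p_pos is the distribution of positive pairs, p_data is the data distribution, and α > 0 and t > 0 are hyperparameters (both default to 2 in the code that follows):

```latex
\mathcal{L}_{\text{align}}(f;\alpha)
  = \mathbb{E}_{(x,y)\sim p_{\text{pos}}}
    \left[ \lVert f(x) - f(y) \rVert_2^{\alpha} \right]

\mathcal{L}_{\text{uniform}}(f;t)
  = \log \, \mathbb{E}_{x,y \,\overset{\text{i.i.d.}}{\sim}\, p_{\text{data}}}
    \left[ e^{-t \lVert f(x) - f(y) \rVert_2^{2}} \right]
```

Lower is better for both: the alignment loss is zero when every positive pair maps to the same point, and the uniformity loss decreases as the features spread out over the hypersphere.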
PyTorch version:
import torch

# bsz : batch size (number of positive pairs)
# d   : latent dim
# x   : Tensor, shape=[bsz, d]
#       latents for one side of positive pairs
# y   : Tensor, shape=[bsz, d]
#       latents for the other side of positive pairs
# lam : hyperparameter balancing the two losses

def lalign(x, y, alpha=2):
    # alignment: mean distance between positive pairs, raised to alpha
    return (x - y).norm(dim=1).pow(alpha).mean()

def lunif(x, t=2):
    # uniformity: log of the mean Gaussian potential over all pairs i < j
    sq_pdist = torch.pdist(x, p=2).pow(2)
    return sq_pdist.mul(-t).exp().mean().log()

loss = lalign(x, y) + lam * (lunif(x) + lunif(y)) / 2
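To get a feel for what the two losses measure, here is a small framework-free sanity check in plain Python. It is only a sketch: the helper names (`align_loss`, `uniform_loss`, `circle_points`) and the toy points on the unit circle are my own assumptions for illustration and are not from the paper.

```python
import math
from itertools import combinations

def align_loss(xs, ys, alpha=2):
    # Mean over positive pairs of ||x - y||^alpha
    return sum(math.dist(x, y) ** alpha for x, y in zip(xs, ys)) / len(xs)

def uniform_loss(xs, t=2):
    # Log of the mean Gaussian potential over all unordered pairs i < j
    vals = [math.exp(-t * math.dist(a, b) ** 2) for a, b in combinations(xs, 2)]
    return math.log(sum(vals) / len(vals))

def circle_points(n, spread):
    # Hypothetical toy data: n unit vectors evenly spaced over an arc of
    # `spread` radians on the unit circle (a 1-D "hypersphere")
    return [(math.cos(spread * k / n), math.sin(spread * k / n)) for k in range(n)]

spread_out = circle_points(8, 2 * math.pi)  # covers the whole circle
clustered = circle_points(8, 0.2)           # packed into a small arc

print("uniformity, spread out:", uniform_loss(spread_out))
print("uniformity, clustered: ", uniform_loss(clustered))
print("alignment, identical pairs:", align_loss(spread_out, spread_out))
```

The spread-out points should give the lower (more negative) uniformity loss, while the clustered points stay close to zero. Identical positive pairs give an alignment loss of exactly zero.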
TensorFlow version:
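import tensorflow as tf

def lalign(x, y, alpha=2):
    """
    x: [bs, d] latents for one side of positive pairs
    y: [bs, d] latents for the other side of positive pairs
    """
    # L2 norm of each positive-pair difference, raised to alpha
    return tf.reduce_mean(tf.pow(tf.norm(x - y, axis=1), alpha))

def lunif(x, t=2):
    """
    x: [bs, d]
    """
    # Emulate torch.pdist: all pairwise distances, then keep pairs with i < j
    x = tf.cast(x, tf.float32)
    # Compute squared distances directly. Going through tf.norm first puts a
    # sqrt at the zero diagonal, which can yield NaN gradients.
    sq_pdist_matrix = tf.reduce_sum(tf.square(x[:, None] - x), axis=2)
    # Strict upper triangle (entries above the diagonal)
    lower_tri = tf.linalg.band_part(tf.ones_like(sq_pdist_matrix), -1, 0)
    bool_mask = tf.cast(1 - lower_tri, tf.bool)
    sq_pdist = tf.boolean_mask(sq_pdist_matrix, bool_mask)
    return tf.math.log(tf.reduce_mean(tf.exp(-t * sq_pdist)))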
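The masking step in the TensorFlow version is easiest to check on indices alone. The plain-Python sketch below (helper names are hypothetical, chosen for illustration) builds the same strict upper-triangle mask and confirms that it selects each unordered pair exactly once, in row-major order, which to my understanding is also the order `torch.pdist` returns.

```python
from itertools import combinations

def strict_upper_mask(n):
    # mask[i][j] is True only for entries above the diagonal (j > i)
    return [[j > i for j in range(n)] for i in range(n)]

def masked_pairs(n):
    # Index pairs kept by the mask, read row by row
    mask = strict_upper_mask(n)
    return [(i, j) for i in range(n) for j in range(n) if mask[i][j]]

n = 4
pairs = masked_pairs(n)
print(pairs)       # [(0, 1), (0, 2), (0, 3), (1, 2), (1, 3), (2, 3)]
print(len(pairs))  # 6, i.e. n * (n - 1) / 2
```

Dropping the diagonal matters: the self-distances are all zero, so including them would add terms equal to exp(0) = 1 to the mean and bias the uniformity loss toward zero.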