
Quantitative Evaluation of Contrastive Learning

Understanding Contrastive Representation Learning through Alignment and Uniformity on the Hypersphere: paper reading notes

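The paper identifies two key properties that the contrastive loss (asymptotically) optimizes:

  • Alignment of positive-pair features: the two samples of a positive pair should be mapped to nearby features.
  • Uniformity of the feature distribution on the hypersphere: features should be spread roughly uniformly over the sphere, which preserves as much information as possible.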
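In the paper, these two properties are measured by the losses below. Here f is the encoder that maps inputs onto the unit hypersphere, p_pos is the distribution of positive pairs, and p_data is the data distribution. The last three lines are our own notation for the mini-batch estimates that the code computes. In them, x_i and y_i are the i-th rows of the latent tensors x and y (n positive pairs), and λ is the `lam` weight.

```latex
\begin{aligned}
\mathcal{L}_{\mathrm{align}}(f;\alpha) &= \mathbb{E}_{(x,y)\sim p_{\mathrm{pos}}}\big[\lVert f(x)-f(y)\rVert_2^{\alpha}\big], && \alpha > 0 \\
\mathcal{L}_{\mathrm{uniform}}(f;t) &= \log \mathbb{E}_{x,y\,\overset{\mathrm{i.i.d.}}{\sim}\,p_{\mathrm{data}}}\big[e^{-t\lVert f(x)-f(y)\rVert_2^{2}}\big], && t > 0 \\
\hat{\mathcal{L}}_{\mathrm{align}}(x,y) &= \frac{1}{n}\sum_{i=1}^{n}\lVert x_i-y_i\rVert_2^{\alpha} \\
\hat{\mathcal{L}}_{\mathrm{uniform}}(x) &= \log\Big(\frac{2}{n(n-1)}\sum_{1\le i<j\le n} e^{-t\lVert x_i-x_j\rVert_2^{2}}\Big) \\
\mathcal{L} &= \hat{\mathcal{L}}_{\mathrm{align}}(x,y) + \lambda\cdot\tfrac{1}{2}\big(\hat{\mathcal{L}}_{\mathrm{uniform}}(x) + \hat{\mathcal{L}}_{\mathrm{uniform}}(y)\big)
\end{aligned}
```

The factor 2/(n(n-1)) is simply the mean over the n(n-1)/2 pairs i < j, which is exactly the set of distances returned by `torch.pdist`.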


PyTorch version:

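```python
import torch
import torch.nn.functional as F

# bsz : batch size (number of positive pairs)
# d   : latent dim
# x   : Tensor, shape=[bsz, d], latents for one side of positive pairs
# y   : Tensor, shape=[bsz, d], latents for the other side of positive pairs
# lam : hyperparameter balancing the two losses
# x and y are assumed to be L2-normalized, i.e. points on the unit hypersphere.
def lalign(x, y, alpha=2):
    # Mean over positive pairs of ||x_i - y_i||_2 ** alpha
    return (x - y).norm(dim=1).pow(alpha).mean()

def lunif(x, t=2):
    # Squared pairwise distances for all pairs i < j
    sq_pdist = torch.pdist(x, p=2).pow(2)
    # Log of the mean Gaussian potential exp(-t * ||x_i - x_j||^2)
    return sq_pdist.mul(-t).exp().mean().log()

# Illustrative inputs only (random features, arbitrary bsz / d / lam)
bsz, d, lam = 128, 32, 1.0
x = F.normalize(torch.randn(bsz, d), dim=1)
y = F.normalize(torch.randn(bsz, d), dim=1)
loss = lalign(x, y) + lam * (lunif(x) + lunif(y)) / 2
```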
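As a quick sanity check of the two losses, consider three cases. This is a sketch on synthetic data, not an experiment from the paper, and the batch size, dimension and seed are arbitrary. Identical positive pairs should give zero alignment loss. A batch collapsed onto a single point should give the worst uniformity value, log(1) = 0. Random unit vectors, which are spread over the sphere, should score well below 0.

```python
import torch
import torch.nn.functional as F

def lalign(x, y, alpha=2):
    return (x - y).norm(dim=1).pow(alpha).mean()

def lunif(x, t=2):
    sq_pdist = torch.pdist(x, p=2).pow(2)
    return sq_pdist.mul(-t).exp().mean().log()

torch.manual_seed(0)
bsz, d = 256, 16
# Random unit vectors: a reasonably uniform spread over the hypersphere
spread = F.normalize(torch.randn(bsz, d), dim=1)
# Every sample mapped to the same point: total collapse
collapsed = F.normalize(torch.ones(bsz, d), dim=1)

align_perfect = lalign(spread, spread).item()  # positive pairs that coincide
unif_spread = lunif(spread).item()
unif_collapsed = lunif(collapsed).item()
print(align_perfect, unif_spread, unif_collapsed)
```

Every term exp(-t·d²) is at most 1, so `lunif` is never positive. It reaches its maximum of 0 only under total collapse, so lower values indicate a more uniform spread.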

TensorFlow version:

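```python
import tensorflow as tf

def lalign(x, y, alpha=2):
    """
    x: [bs, d] latents for one side of positive pairs
    y: [bs, d] latents for the other side of positive pairs
    """
    # L2 norm of each positive-pair difference, raised to alpha, then averaged
    return tf.reduce_mean(tf.pow(tf.norm(x - y, axis=1), alpha))

def lunif(x, t=2):
    """
    x: [bs, d]
    """
    x = tf.cast(x, tf.float32)
    batch_size = tf.shape(x)[0]
    # Equivalent of torch.pdist(x).pow(2): squared distances for all pairs i < j.
    # Summing squares directly (instead of tf.norm followed by squaring) avoids
    # the NaN gradient of sqrt at the zero-distance diagonal entries.
    sq_dist_matrix = tf.reduce_sum(tf.square(x[:, None] - x[None, :]), axis=2)
    # band_part(ones, -1, 0) keeps the lower triangle incl. the diagonal,
    # so 1 - band_part(...) is the strict upper triangle (diagonal excluded)
    upper = 1 - tf.linalg.band_part(tf.ones((batch_size, batch_size)), -1, 0)
    sq_pdist = tf.boolean_mask(sq_dist_matrix, tf.cast(upper, tf.bool))
    return tf.math.log(tf.reduce_mean(tf.exp(-t * sq_pdist)))
```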
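The following sketch checks that the strict upper-triangular mask selects the same distances as `torch.pdist`. NumPy stands in for TensorFlow here so that the check runs without either framework, and the helper names `sq_pdist_masked`, `sq_pdist_condensed` and `lunif_np` are hypothetical. `np.tril` plays the role of `tf.linalg.band_part(ones, -1, 0)`. `np.triu_indices(n, k=1)` enumerates the pairs i < j in row-major order, which is the order `torch.pdist` uses.

```python
import numpy as np

def sq_pdist_masked(x):
    """Mirror of the TF lunif approach: full squared-distance matrix + strict upper mask."""
    n = x.shape[0]
    sq = np.sum((x[:, None, :] - x[None, :, :]) ** 2, axis=2)
    # np.tril(ones) ~ tf.linalg.band_part(ones, -1, 0): lower triangle incl. diagonal
    upper = (1 - np.tril(np.ones((n, n)))).astype(bool)
    return sq[upper]  # row-major order: (0,1), (0,2), ..., (1,2), ...

def sq_pdist_condensed(x):
    """Reference: squared distances for pairs i < j, in the same order as torch.pdist."""
    i, j = np.triu_indices(x.shape[0], k=1)
    return np.sum((x[i] - x[j]) ** 2, axis=1)

def lunif_np(x, t=2):
    return np.log(np.mean(np.exp(-t * sq_pdist_condensed(x))))

rng = np.random.default_rng(0)
x = rng.standard_normal((5, 3))
x /= np.linalg.norm(x, axis=1, keepdims=True)  # project onto the unit sphere

a = sq_pdist_masked(x)
b = sq_pdist_condensed(x)
print(a.shape)            # (10,) = 5 * 4 / 2 pairs
print(np.allclose(a, b))  # True
```

For the loss value, the essential part of the mask is excluding the diagonal: its distance-0 entries contribute exp(0) = 1 each and would inflate the mean. Keeping only i < j additionally reproduces the n(n-1)/2 entries that `torch.pdist` returns.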
Posted 2022-06-28 15:54 by 戴墨镜的长颈鹿