torch.einsum()

讲解
对比学习论文中出现:

# compute logits
# Einstein sum is more intuitive
# positive logits: Nx1
l_pos = torch.einsum('nc,nc->n', [q, k]).unsqueeze(-1)
# negative logits: NxK
l_neg = torch.einsum('nc,ck->nk', [q, self.queue.clone().detach()])
posted @ 2022-02-22 20:27  zae  阅读(647)  评论(0编辑  收藏  举报