CLIP损失函数的理解
参考资料:
[CLIP huggingface源码:CLIPModel]
这篇文章首先展示CLIP损失函数的两种底层实现代码,然后聊一聊自己的理解。
说实话念硕士的时候没有接触过CLIP这个东西,来实习之后发现这个多模态的模型使用非常广泛,设计理念也是看后惊为天人。加上最近有探究任务研究CLIP,BLIP这些,遂决心把这个模型弄懂。参考资料1已经把CLIP的设计思想,原理,甚至是底层实现给讲清楚了,但是当我读到训练的损失函数那一段的时候还是产生了很大的疑问:作者说有两种方式来计算损失函数,一种较为简单,一种较为复杂。较为复杂的损失函数实现如下:
def forward(self, batch): # Getting Image and Text Features image_features = self.image_encoder(batch["image"]) text_features = self.text_encoder( input_ids=batch["input_ids"], attention_mask=batch["attention_mask"] ) # Getting Image and Text Embeddings (with same dimension) image_embeddings = self.image_projection(image_features) text_embeddings = self.text_projection(text_features) # Calculating the Loss logits = (text_embeddings @ image_embeddings.T) / self.temperature images_similarity = image_embeddings @ image_embeddings.T texts_similarity = text_embeddings @ text_embeddings.T targets = F.softmax( (images_similarity + texts_similarity) / 2 * self.temperature, dim=-1 ) texts_loss = cross_entropy(logits, targets, reduction='none') images_loss = cross_entropy(logits.T, targets.T, reduction='none') loss = (images_loss + texts_loss) / 2.0 # shape: (batch_size) return loss.mean()
其中Cross_entropy也是作者自己实现的,看上去就是logsoftmax加上NLLloss:
def cross_entropy(preds, targets, reduction='none'): log_softmax = nn.LogSoftmax(dim=-1) loss = (-targets * log_softmax(preds)).sum(1) if reduction == "none": return loss elif reduction == "mean": return loss.mean()
较为简单的损失函数的实现则是这样:nn.CrossEntropyLoss()(logits, torch.arange(batch_size))
作者在下面进行了分析,我看完分析之后觉得... ... 作者的语气好像是在说这种较为简单的损失函数是有误的,在数据集中有同一张图片的多个相似caption的时候会明显犯错。那么,较为复杂的损失函数就是正确的了。以上是Tutorial里作者的实现,较为权威的另一种实现是huggingface团队Transformer库里的源码。由于CLIP模型的高度可定制性,huggingface团队实现了一个基类,也就是CLIPModel部分。并在需要训练的时候把loss设置为forward函数的第一个返回值,我们来看一下他们的实现:
image_embeds = vision_outputs[1] image_embeds = self.visual_projection(image_embeds) text_embeds = text_outputs[1] text_embeds = self.text_projection(text_embeds) # normalized features image_embeds = image_embeds / image_embeds.norm(p=2, dim=-1, keepdim=True) text_embeds = text_embeds / text_embeds.norm(p=2, dim=-1, keepdim=True) # cosine similarity as logits logit_scale = self.logit_scale.exp() logits_per_text = torch.matmul(text_embeds, image_embeds.t()) * logit_scale logits_per_image = logits_per_text.t() loss = None if return_loss: loss = clip_loss(logits_per_text)
其中,clip_loss的实现如下:
# contrastive loss function, adapted from # https://sachinruk.github.io/blog/pytorch/pytorch%20lightning/loss%20function/gpu/2021/03/07/CLIP.html def contrastive_loss(logits: torch.Tensor) -> torch.Tensor: return nn.functional.cross_entropy(logits, torch.arange(len(logits), device=logits.device)) def clip_loss(similarity: torch.Tensor) -> torch.Tensor: caption_loss = contrastive_loss(similarity) image_loss = contrastive_loss(similarity.t()) return (caption_loss + image_loss) / 2.0
一开始的归一化比较好理解,logit_scale是一个超参数也好理解。最难理解的就是logits_per_text和logits_per_image这两个互为转置的矩阵。写这篇文章的时候我只能说自己弄懂了7分,原论文中有这么一段话:While standard image models jointly train an image feature extractor and a linear classifier to predict some label, CLIP jointly trains an image encoder and a text encoder to predict the correct pairings of a batch of (image, text) training examples. 即CLIP是学习(image, text)图文对之间的正确匹配的。这个正确匹配有两个对称的方面:1)对于每一个caption,和它吻合的图片得到label 1,和它不吻合的图片得到label 0。(这个对应于caption_loss)2)对于每一个image,和它吻合的caption得到label 1,和它不吻合的caption得到label 0。(这个对应于image_loss)而将两个loss相加除以2,得到的损失函数就同时考虑了两个方面了。如果一个模型在这两个方面都做得好,那么大概率是能够成功学习到correct pairings of a batch of (image, text) 的。