这里给出论文 Supervised Contrastive Learning(SupContrast)中损失函数的 TensorFlow 版本,代码改自:https://github.com/HobbitLong/SupContrast
损失文件losses.py
"""
Author: Yonglong Tian (yonglong@mit.edu)
Date: May 07, 2020
"""
from __future__ import print_function

import tensorflow as tf


class SupConLoss(object):
    """Supervised Contrastive Learning: https://arxiv.org/pdf/2004.11362.pdf.

    It also supports the unsupervised contrastive loss in SimCLR.

    Attributes:
        temperature: divisor applied to the anchor/contrast dot products.
        contrast_mode: 'all' uses every view as an anchor, 'one' uses only
            the first view of each sample.
        base_temperature: the final loss is scaled by
            `temperature / base_temperature`.
    """

    def __init__(self, temperature=0.07, contrast_mode='all',
                 base_temperature=0.07):
        super(SupConLoss, self).__init__()
        self.temperature = temperature
        self.contrast_mode = contrast_mode
        self.base_temperature = base_temperature

    def forward(self, features, labels=None, mask=None):
        """Compute loss for model.

        If both `labels` and `mask` are None, it degenerates to the SimCLR
        unsupervised loss: https://arxiv.org/pdf/2002.05709.pdf

        Args:
            features: hidden vector of shape [bsz, n_views, ...]. The
                `n_views` dimension must be statically known.
            labels: ground truth of shape [bsz].
            mask: contrastive mask of shape [bsz, bsz], mask_{i,j}=1 if
                sample j has the same class as sample i. Can be asymmetric.

        Returns:
            A loss scalar.

        Raises:
            ValueError: if `features` has fewer than 3 dimensions, if its
                `n_views` dimension is not statically known, if both
                `labels` and `mask` are given, or if `contrast_mode` is
                neither 'one' nor 'all'.
        """
        sizes = features.get_shape().as_list()
        if len(sizes) < 3:
            raise ValueError('`features` needs to be [bsz, n_views, ...],'
                             'at least 3 dimensions are required')

        # Read n_views from the static shape *before* any reshape: reshaping
        # with dynamic `tf.shape` values can erase the static dimension that
        # `tf.unstack` and `tf.tile` below rely on.
        contrast_count = sizes[1]
        if contrast_count is None:
            raise ValueError('`features` needs a statically known n_views '
                             'dimension')

        dtype = features.dtype
        batch_size = tf.shape(features)[0]
        if len(sizes) > 3:
            # Flatten all trailing dimensions into a single feature axis.
            features = tf.reshape(features, [batch_size, contrast_count, -1])

        if labels is not None and mask is not None:
            raise ValueError('Cannot define both `labels` and `mask`')
        elif labels is None and mask is None:
            # Unsupervised case: a sample is only positive with itself.
            mask = tf.eye(batch_size, dtype=dtype)
        elif labels is not None:
            labels = tf.reshape(labels, [-1, 1])
            mask = tf.cast(tf.equal(labels, tf.transpose(labels, [1, 0])),
                           dtype=dtype)
        else:
            mask = tf.cast(mask, dtype=dtype)

        # Stack the views along the batch axis: [bsz * n_views, dim].
        contrast_feature = tf.concat(tf.unstack(features, axis=1), axis=0)
        if self.contrast_mode == 'one':
            anchor_feature = features[:, 0]
            anchor_count = 1
        elif self.contrast_mode == 'all':
            anchor_feature = contrast_feature
            anchor_count = contrast_count
        else:
            raise ValueError('Unknown mode: {}'.format(self.contrast_mode))

        # compute logits
        anchor_dot_contrast = tf.matmul(
            anchor_feature, contrast_feature,
            transpose_b=True) / self.temperature
        # for numerical stability
        logits_max = tf.reduce_max(anchor_dot_contrast, axis=1, keepdims=True)
        logits = anchor_dot_contrast - tf.stop_gradient(logits_max)

        # tile mask to [bsz * anchor_count, bsz * contrast_count]
        mask = tf.tile(mask, [anchor_count, contrast_count])

        # mask-out self-contrast cases: anchor i sits at column i of the
        # contrast axis. The one-hot depth must be the number of contrast
        # columns (not the number of anchors) so that the shapes also agree
        # when contrast_mode == 'one'.
        logits_mask = 1.0 - tf.one_hot(
            tf.range(batch_size * anchor_count),
            depth=batch_size * contrast_count,
            dtype=dtype)
        mask = mask * logits_mask

        # compute log_prob
        exp_logits = tf.exp(logits) * logits_mask
        log_prob = logits - tf.math.log(
            tf.reduce_sum(exp_logits, axis=1, keepdims=True))

        # compute mean of log-likelihood over positive
        # Anchors without any positive would give 0 / 0 = NaN; dividing by 1
        # instead makes them contribute 0 to the loss.
        pos_count = tf.reduce_sum(mask, axis=1)
        pos_count = tf.where(pos_count < 1e-6,
                             tf.ones_like(pos_count),
                             pos_count)
        mean_log_prob_pos = tf.reduce_sum(mask * log_prob, axis=1) / pos_count

        # loss
        loss = -(self.temperature / self.base_temperature) * mean_log_prob_pos
        loss = tf.reduce_mean(tf.reshape(loss, [anchor_count, batch_size]))
        return loss
测试:
import os

import tensorflow as tf

import losses

# Expose only GPU 0 to TensorFlow; this is set before the session is created.
os.environ["CUDA_VISIBLE_DEVICES"] = '0'

loss = losses.SupConLoss()
# 10 samples, 2 views each, 5-dimensional features.
X = tf.random_uniform([10, 2, 5])
# Random binary class labels (0 or 1), one per sample.
y = tf.random_uniform([10], minval=0, maxval=2, dtype=tf.int32)

# Use the session as a context manager so its resources are released
# even if evaluating the loss raises.
with tf.Session() as sess:
    print(sess.run(loss.forward(X, y)))
输出(输入为随机数,每次运行结果不同),例如:8.23587
【推荐】国内首个AI IDE,深度理解中文开发场景,立即下载体验Trae
【推荐】编程新体验,更懂你的AI,立即体验豆包MarsCode编程助手
【推荐】抖音旗下AI助手豆包,你的智能百科全书,全免费不限次数
【推荐】轻量又高性能的 SSH 工具 IShell:AI 加持,快人一步
· 开发者必知的日志记录最佳实践
· SQL Server 2025 AI相关能力初探
· Linux系列:如何用 C#调用 C方法造成内存泄露
· AI与.NET技术实操系列(二):开始使用ML.NET
· 记一次.NET内存居高不下排查解决与启示
· 阿里最新开源QwQ-32B,效果媲美deepseek-r1满血版,部署成本又又又降低了!
· 开源Multi-agent AI智能体框架aevatar.ai,欢迎大家贡献代码
· Manus重磅发布:全球首款通用AI代理技术深度解析与实战指南
· 被坑几百块钱后,我竟然真的恢复了删除的微信聊天记录!
· 没有Manus邀请码?试试免邀请码的MGX或者开源的OpenManus吧