多分类的FocalLoss损失函数(torch实现三分类为例)

import torch
import torch.nn as nn
import torch.nn.functional as F

# Demo batch: 16 samples of 3-class logits, every sample labelled class 0.
x = torch.rand(16, 3)
y = torch.zeros(16, dtype=torch.long)

# If a class makes up a fraction X of the samples, its alpha is 1 - X: the
# more samples a class has, the less weight it should receive.
alpha = torch.tensor([[0.43], [0.75], [0.82]])
class_num = 3


class FocalLoss(nn.Module):
    """Multi-class focal loss computed from raw (un-normalized) logits.

    For a sample whose true class is ``t`` the per-sample loss is
    ``-alpha[t] * (1 - p_t) ** gamma * log(p_t)``, where ``p_t`` is the
    softmax probability assigned to the true class.

    Args:
        class_num: Number of classes; determines the size of the default
            ``alpha`` when none is given.
        alpha: Per-class weights, one value per class, given as a tensor (or
            anything ``torch.as_tensor`` accepts) of shape ``(class_num,)``
            or ``(class_num, 1)``. Defaults to all ones.
        gamma: Exponent applied to ``(1 - p_t)``; larger values shrink the
            contribution of samples whose true class already has a high
            probability.
        size_average: If true, return the mean of the per-sample losses;
            otherwise return their sum.
    """

    def __init__(self, class_num, alpha=None, gamma=2, size_average=True):
        super(FocalLoss, self).__init__()
        if alpha is None:
            alpha = torch.ones(class_num, 1)
        # Stored as a plain tensor attribute (not an nn.Parameter), so it is
        # never updated by an optimizer. Normalized to a (C, 1) column so
        # that indexing by class id always yields an (N, 1) weight column.
        self.alpha = torch.as_tensor(alpha).reshape(-1, 1)
        self.gamma = gamma
        self.class_num = class_num
        self.size_average = size_average

    def forward(self, inputs, targets):
        """Return the focal loss of a batch as a scalar tensor.

        Args:
            inputs: Logits of shape ``(N, C)``.
            targets: Integer (``torch.long``) class indices with ``N``
                elements.

        Returns:
            A 0-dim tensor: the mean of the per-sample losses when
            ``size_average`` is true, otherwise their sum.
        """
        ids = targets.view(-1, 1)

        # Keep the class weights on the same device as the logits.
        if self.alpha.device != inputs.device:
            self.alpha = self.alpha.to(inputs.device)

        # Work in log space: softmax followed by log() produces -inf as soon
        # as the true-class probability underflows to zero.
        log_p = F.log_softmax(inputs, dim=1).gather(1, ids)  # (N, 1)
        probs = log_p.exp()
        sample_alpha = self.alpha[ids.view(-1)]  # (N, 1)

        batch_loss = -sample_alpha * torch.pow(1 - probs, self.gamma) * log_p

        if self.size_average:
            return batch_loss.mean()
        return batch_loss.sum()


loss = FocalLoss(class_num=class_num, alpha=alpha)
loss(x, y)

  

posted @ 2022-07-28 19:31  麦扣  阅读(1561)  评论(0)  编辑  收藏  举报