# 多分类的 FocalLoss 损失函数（以 torch 实现、三分类为例）
# Multi-class Focal Loss in PyTorch, illustrated with a 3-class example.
# Dummy batch: 16 samples of 3-class scores, every sample labelled class 0.
x = torch.rand(16, 3)
y = torch.zeros(16, dtype=torch.long)
# Per-class weights: if class k makes up a fraction X_k of the samples, use
# alpha_k = 1 - X_k, so classes with more samples receive less weight.
# (These values correspond to class proportions 0.57 / 0.25 / 0.18.)
# Shape is (class_num, 1): one row per class.
alpha = torch.tensor([[0.43], [0.75], [0.82]])
# Keep the class count consistent with the number of weights.
class_num = alpha.size(0)
class FocalLoss(nn.Module):
    """Multi-class focal loss computed from unnormalised logits.

    For a sample whose true class is ``t`` and whose softmax probability for
    that class is ``p_t``, the per-sample loss is::

        -alpha[t] * (1 - p_t) ** gamma * log(p_t)

    Args:
        class_num: Number of classes ``C``; used to build the default
            all-ones ``alpha`` when none is given.
        alpha: Per-class weights, shape ``(C,)`` or ``(C, 1)`` (tensor or
            sequence). Defaults to all ones (no class re-weighting). An
            ``nn.Parameter`` is registered as a learnable parameter.
        gamma: Focusing exponent; ``0`` reduces to alpha-weighted cross entropy.
        size_average: If True return the batch mean, otherwise the batch sum.
    """

    def __init__(self, class_num, alpha=None, gamma=2, size_average=True):
        super().__init__()
        if alpha is None:
            alpha = torch.ones(class_num, 1)
        if isinstance(alpha, nn.Parameter):
            # Keep a learnable alpha visible to module.parameters().
            self.alpha = alpha
        else:
            # A non-persistent buffer follows module.to()/.cuda(), so forward
            # never has to reassign it, and it stays out of state_dict (as the
            # original plain tensor attribute did).
            self.register_buffer("alpha", torch.as_tensor(alpha), persistent=False)
        self.gamma = gamma
        self.class_num = class_num
        self.size_average = size_average

    def forward(self, inputs, targets):
        """Return the focal loss for a batch.

        Args:
            inputs: Logits of shape ``(N, C)``; softmax is taken over dim 1.
            targets: Integer class indices of shape ``(N,)`` or ``(N, 1)``.

        Returns:
            Scalar tensor: mean (``size_average=True``) or sum of the
            per-sample losses.
        """
        ids = targets.reshape(-1)
        # log_softmax is numerically stable; softmax(...).log() yields -inf
        # (an infinite loss) once the true-class probability underflows to 0.
        log_p = F.log_softmax(inputs, dim=1).gather(1, ids.view(-1, 1)).view(-1)
        probs = log_p.exp()
        # Flatten so that a (C,) or (C, 1) alpha both give one weight per
        # sample; mixing (N,) with (N, 1) would otherwise broadcast to (N, N).
        alpha = self.alpha.to(device=inputs.device, dtype=inputs.dtype).reshape(-1)[ids]
        batch_loss = -alpha * torch.pow(1 - probs, self.gamma) * log_p
        if self.size_average:
            return batch_loss.mean()
        return batch_loss.sum()
# Build the criterion with the per-class weights defined above and evaluate it
# on the dummy batch (x, y).
loss = FocalLoss(class_num=class_num, alpha=alpha)
# A bare expression is only echoed in a REPL/notebook; print it so a plain
# script run also shows the resulting loss.
print(loss(x, y))