Class-Balanced Loss Based on Effective Number of Samples - 2 - 代码学习
参考:https://github.com/vandit15/Class-balanced-loss-pytorch
其中的class_balanced_loss.py:
import numpy as np
import torch
import torch.nn.functional as F


def focal_loss(labels, logits, alpha, gamma):
    """Compute the focal loss between `logits` and the ground truth `labels`.

    Focal loss = -alpha_t * (1-pt)^gamma * log(pt)
    where pt is the probability of being classified to the true class.
    pt = p (if true class), otherwise pt = 1 - p. p = sigmoid(logit).

    Args:
      labels: A float tensor of size [batch, num_classes] holding 0/1 targets.
      logits: A float tensor of size [batch, num_classes].
      alpha: A float tensor broadcastable to [batch, num_classes] specifying
        the per-element weight for balanced cross entropy (CB_loss passes a
        [batch, num_classes] tensor repeating each example's class weight).
      gamma: A float scalar modulating loss from hard and easy examples.

    Returns:
      focal_loss: A scalar tensor: the summed weighted loss divided by the
        number of positive labels (the batch size for one-hot labels).
    """
    # Element-wise sigmoid cross entropy, i.e. -log(pt).
    BCLoss = F.binary_cross_entropy_with_logits(input=logits, target=labels, reduction="none")

    if gamma == 0.0:
        modulator = 1.0
    else:
        # (1 - pt)^gamma evaluated in log space: for label 1 it equals
        # sigmoid(-x)^gamma, for label 0 sigmoid(x)^gamma. softplus(-x) is
        # log(1 + exp(-x)) without overflowing to inf for very negative
        # logits, which would otherwise zero the modulator of hard positives.
        modulator = torch.exp(-gamma * labels * logits - gamma * F.softplus(-logits))

    loss = modulator * BCLoss
    weighted_loss = alpha * loss

    # Normalise by the number of positive targets.
    total_loss = torch.sum(weighted_loss)
    total_loss /= torch.sum(labels)
    return total_loss


def CB_loss(labels, logits, samples_per_cls, no_of_classes, loss_type, beta, gamma):
    """Compute the Class Balanced Loss between `logits` and the ground truth `labels`.

    Class Balanced Loss: ((1-beta)/(1-beta^n))*Loss(labels, logits)
    where Loss is one of the standard losses used for Neural Networks.

    Args:
      labels: A int tensor of size [batch].
      logits: A float tensor of size [batch, no_of_classes].
      samples_per_cls: A python list of size [no_of_classes].
      no_of_classes: total number of classes. int
      loss_type: string. One of "sigmoid", "focal", "softmax".
      beta: float. Hyperparameter for Class balanced loss.
      gamma: float. Hyperparameter for Focal loss.

    Returns:
      cb_loss: A float tensor representing class balanced loss

    Raises:
      ValueError: if `loss_type` is not "sigmoid", "focal" or "softmax".
    """
    if loss_type not in ("focal", "sigmoid", "softmax"):
        raise ValueError(
            f"Unsupported loss_type {loss_type!r}; expected 'focal', 'sigmoid' or 'softmax'"
        )

    # Per-class weight (1 - beta) / (1 - beta^n_c): inverse effective number of samples.
    effective_num = 1.0 - np.power(beta, samples_per_cls)
    weights = (1.0 - beta) / np.array(effective_num)
    # Rescale so the class weights sum to no_of_classes.
    weights = weights / np.sum(weights) * no_of_classes

    # One-hot targets let the (binary) losses below handle the multi-class case.
    labels_one_hot = F.one_hot(labels, no_of_classes).to(logits.dtype)

    # Keep the weights on the same device/dtype as the logits.
    class_weights = torch.tensor(weights, dtype=logits.dtype, device=logits.device)
    # Select each example's own class weight ([batch, 1]) and repeat it across
    # every class column -> [batch, no_of_classes].
    weights = (labels_one_hot * class_weights).sum(dim=1, keepdim=True)
    weights = weights.repeat(1, no_of_classes)

    if loss_type == "focal":
        cb_loss = focal_loss(labels_one_hot, logits, weights, gamma)
    elif loss_type == "sigmoid":
        # The keyword argument is `weight` (singular).
        cb_loss = F.binary_cross_entropy_with_logits(
            input=logits, target=labels_one_hot, weight=weights
        )
    else:  # "softmax"
        pred = logits.softmax(dim=1)
        cb_loss = F.binary_cross_entropy(input=pred, target=labels_one_hot, weight=weights)
    return cb_loss


if __name__ == '__main__':
    no_of_classes = 5
    logits = torch.rand(10, no_of_classes).float()
    labels = torch.randint(0, no_of_classes, size=(10,))
    beta = 0.9999
    gamma = 2.0
    samples_per_cls = [2, 3, 1, 2, 2]
    loss_type = "focal"
    cb_loss = CB_loss(labels, logits, samples_per_cls, no_of_classes, loss_type, beta, gamma)
    print(cb_loss)
添加注释和输出的版本:
#coding:utf-8
import numpy as np
import torch
import torch.nn.functional as F


def focal_loss(labels, logits, alpha, gamma):
    """Compute the focal loss between `logits` and the ground truth `labels`.

    Focal loss = -alpha_t * (1-pt)^gamma * log(pt)
    where pt is the probability of being classified to the true class.
    pt = p (if true class), otherwise pt = 1 - p. p = sigmoid(logit).

    Args:
      labels: A float tensor of size [batch, num_classes] holding 0/1 targets.
      logits: A float tensor of size [batch, num_classes].
      alpha: A float tensor broadcastable to [batch, num_classes] specifying
        the per-element weight for balanced cross entropy (CB_loss passes a
        [batch, num_classes] tensor repeating each example's class weight).
      gamma: A float scalar modulating loss from hard and easy examples.

    Returns:
      focal_loss: A scalar tensor: the summed weighted loss divided by the
        number of positive labels (the batch size for one-hot labels).
    """
    # Element-wise sigmoid cross entropy, i.e. -log(pt).
    BCLoss = F.binary_cross_entropy_with_logits(input=logits, target=labels, reduction="none")

    if gamma == 0.0:
        modulator = 1.0
    else:
        # (1 - pt)^gamma evaluated in log space: for label 1 it equals
        # sigmoid(-x)^gamma, for label 0 sigmoid(x)^gamma. softplus(-x) is
        # log(1 + exp(-x)) without overflowing to inf for very negative
        # logits, which would otherwise zero the modulator of hard positives.
        modulator = torch.exp(-gamma * labels * logits - gamma * F.softplus(-logits))

    loss = modulator * BCLoss

    weighted_loss = alpha * loss
    # Sum the weighted loss and normalise by the number of positive labels.
    total_loss = torch.sum(weighted_loss)
    total_loss /= torch.sum(labels)
    return total_loss


def CB_loss(labels, logits, samples_per_cls, no_of_classes, loss_type, beta, gamma):
    """Compute the Class Balanced Loss between `logits` and the ground truth `labels`.

    Class Balanced Loss: ((1-beta)/(1-beta^n))*Loss(labels, logits)
    where Loss is one of the standard losses used for Neural Networks.
    This annotated version prints every intermediate step of the weight computation.

    Args:
      labels: A int tensor of size [batch].
      logits: A float tensor of size [batch, no_of_classes].
      samples_per_cls: A python list of size [no_of_classes].
      no_of_classes: total number of classes. int
      loss_type: string. One of "sigmoid", "focal", "softmax".
      beta: float. Hyperparameter for Class balanced loss.
      gamma: float. Hyperparameter for Focal loss.

    Returns:
      cb_loss: A float tensor representing class balanced loss

    Raises:
      ValueError: if `loss_type` is not "sigmoid", "focal" or "softmax".
    """
    if loss_type not in ("focal", "sigmoid", "softmax"):
        raise ValueError(
            f"Unsupported loss_type {loss_type!r}; expected 'focal', 'sigmoid' or 'softmax'"
        )

    # Compute (1-beta)/(1-beta^n) per class: the class weight used in the loss.
    effective_num = 1.0 - np.power(beta, samples_per_cls)
    print('effective_num shape: ', effective_num.shape)
    print(effective_num)
    weights = (1.0 - beta) / np.array(effective_num)
    print('weights shape : ', weights.shape)
    print(weights)
    weights = weights / np.sum(weights) * no_of_classes  # normalise: weights sum to no_of_classes
    print('weights shape : ', weights.shape)
    print(weights)

    labels_one_hot = F.one_hot(labels, no_of_classes).to(logits.dtype)
    print('labels_one_hot shape: ', labels_one_hot.shape)
    print(labels_one_hot)
    print('-'*50)

    # Keep the weights on the same device/dtype as the logits.
    weights = torch.tensor(weights, dtype=logits.dtype, device=logits.device)
    weights = weights.unsqueeze(0)
    print('unsqueeze weights shape : ', weights.shape)
    print(weights)
    # labels_one_hot.shape[0] is the number of samples. In weights.repeat(n, 1)
    # the 1 leaves the class dimension unchanged, so the single row is repeated
    # n times: [1, 5] -> [10, 5] in the demo below.
    print(weights.repeat(labels_one_hot.shape[0], 1))
    # Multiplying by the one-hot labels keeps only each sample's own class weight.
    weights = weights.repeat(labels_one_hot.shape[0], 1) * labels_one_hot
    print('repeat weights shape : ', weights.shape)
    print(weights)
    weights = weights.sum(1)  # sum over dim=1 (classes), keeping dim 0: one weight per sample
    print('sum weights shape : ', weights.shape)
    print(weights)
    weights = weights.unsqueeze(1)
    print('unsqueeze weights shape : ', weights.shape)
    print(weights)
    weights = weights.repeat(1, no_of_classes)  # repeat along columns: [10, 1] -> [10, 5] in the demo
    print('repeat weights shape : ', weights.shape)
    print(weights)
    print('-'*50)

    if loss_type == "focal":
        cb_loss = focal_loss(labels_one_hot, logits, weights, gamma)
    elif loss_type == "sigmoid":
        # The keyword argument is `weight` (singular).
        cb_loss = F.binary_cross_entropy_with_logits(
            input=logits, target=labels_one_hot, weight=weights
        )
    else:  # "softmax"
        pred = logits.softmax(dim=1)
        cb_loss = F.binary_cross_entropy(input=pred, target=labels_one_hot, weight=weights)
    return cb_loss


if __name__ == '__main__':
    no_of_classes = 5  # 10 samples, 5 classes
    logits = torch.rand(10, no_of_classes).float()  # random scores for 10 samples x 5 classes, used as logits
    print('logits shape : ', logits.shape)
    print(logits)
    labels = torch.randint(0, no_of_classes, size=(10,))  # ground-truth class of each of the 10 samples
    print('labels shape : ', labels.shape)
    print(labels)
    beta = 0.9999  # hyperparameter
    gamma = 2.0  # hyperparameter
    samples_per_cls = [2, 3, 1, 2, 2]  # number of training samples per class
    loss_type = "focal"
    cb_loss = CB_loss(labels, logits, samples_per_cls, no_of_classes, loss_type, beta, gamma)
    print(cb_loss)
运行输出:
(deeplearning) bogon:work_gender_age wanghui$ python test_delete.py logits shape : torch.Size([10, 5]) tensor([[0.1505, 0.9621, 0.8622, 0.0237, 0.0270], [0.6218, 0.2745, 0.3015, 0.1501, 0.1728], [0.3590, 0.1760, 0.0807, 0.7440, 0.6973], [0.9401, 0.7118, 0.1725, 0.1843, 0.3226], [0.4655, 0.8319, 0.4336, 0.8718, 0.5842], [0.9423, 0.3339, 0.1081, 0.4718, 0.4329], [0.5122, 0.7010, 0.1736, 0.5903, 0.0712], [0.6442, 0.1365, 0.1391, 0.8278, 0.5986], [0.1245, 0.5662, 0.9571, 0.8515, 0.9883], [0.4654, 0.8924, 0.0224, 0.9056, 0.4517]]) labels shape : torch.Size([10]) tensor([0, 3, 4, 2, 2, 1, 4, 0, 1, 2]) effective_num shape: (5,) [1.99990000e-04 2.99970001e-04 1.00000000e-04 1.99990000e-04 1.99990000e-04] weights shape : (5,) [0.500025 0.33336667 1. 0.500025 0.500025 ] weights shape : (5,) [0.88236332 0.58827163 1.76463841 0.88236332 0.88236332] labels_one_hot shape: torch.Size([10, 5]) tensor([[1., 0., 0., 0., 0.], [0., 0., 0., 1., 0.], [0., 0., 0., 0., 1.], [0., 0., 1., 0., 0.], [0., 0., 1., 0., 0.], [0., 1., 0., 0., 0.], [0., 0., 0., 0., 1.], [1., 0., 0., 0., 0.], [0., 1., 0., 0., 0.], [0., 0., 1., 0., 0.]]) -------------------------------------------------- unsqueeze weights shape : torch.Size([1, 5]) tensor([[0.8824, 0.5883, 1.7646, 0.8824, 0.8824]]) tensor([[0.8824, 0.5883, 1.7646, 0.8824, 0.8824], [0.8824, 0.5883, 1.7646, 0.8824, 0.8824], [0.8824, 0.5883, 1.7646, 0.8824, 0.8824], [0.8824, 0.5883, 1.7646, 0.8824, 0.8824], [0.8824, 0.5883, 1.7646, 0.8824, 0.8824], [0.8824, 0.5883, 1.7646, 0.8824, 0.8824], [0.8824, 0.5883, 1.7646, 0.8824, 0.8824], [0.8824, 0.5883, 1.7646, 0.8824, 0.8824], [0.8824, 0.5883, 1.7646, 0.8824, 0.8824], [0.8824, 0.5883, 1.7646, 0.8824, 0.8824]]) repeat weights shape : torch.Size([10, 5]) tensor([[0.8824, 0.0000, 0.0000, 0.0000, 0.0000], [0.0000, 0.0000, 0.0000, 0.8824, 0.0000], [0.0000, 0.0000, 0.0000, 0.0000, 0.8824], [0.0000, 0.0000, 1.7646, 0.0000, 0.0000], [0.0000, 0.0000, 1.7646, 0.0000, 0.0000], [0.0000, 0.5883, 0.0000, 0.0000, 0.0000], 
[0.0000, 0.0000, 0.0000, 0.0000, 0.8824], [0.8824, 0.0000, 0.0000, 0.0000, 0.0000], [0.0000, 0.5883, 0.0000, 0.0000, 0.0000], [0.0000, 0.0000, 1.7646, 0.0000, 0.0000]]) sum weights shape : torch.Size([10]) tensor([0.8824, 0.8824, 0.8824, 1.7646, 1.7646, 0.5883, 0.8824, 0.8824, 0.5883, 1.7646]) unsqueeze weights shape : torch.Size([10, 1]) tensor([[0.8824], [0.8824], [0.8824], [1.7646], [1.7646], [0.5883], [0.8824], [0.8824], [0.5883], [1.7646]]) repeat weights shape : torch.Size([10, 5]) tensor([[0.8824, 0.8824, 0.8824, 0.8824, 0.8824], [0.8824, 0.8824, 0.8824, 0.8824, 0.8824], [0.8824, 0.8824, 0.8824, 0.8824, 0.8824], [1.7646, 1.7646, 1.7646, 1.7646, 1.7646], [1.7646, 1.7646, 1.7646, 1.7646, 1.7646], [0.5883, 0.5883, 0.5883, 0.5883, 0.5883], [0.8824, 0.8824, 0.8824, 0.8824, 0.8824], [0.8824, 0.8824, 0.8824, 0.8824, 0.8824], [0.5883, 0.5883, 0.5883, 0.5883, 0.5883], [1.7646, 1.7646, 1.7646, 1.7646, 1.7646]]) -------------------------------------------------- tensor(1.9583)
可见，代码之所以能用二分类（sigmoid）交叉熵来计算多分类问题的损失，主要是因为将 labels 转换成了 one-hot 格式：
labels_one_hot = F.one_hot(labels, no_of_classes).float()  # int class ids [batch] -> 0/1 float matrix [batch, no_of_classes]
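其中比较复杂的是 focal loss 的实现，它由三部分相乘得到（loss = alpha * modulator * BCLoss）：
1) BCLoss = F.binary_cross_entropy_with_logits(..., reduction="none")：逐元素计算 sigmoid 二分类交叉熵，即 $-\log(p_t)$；
2) modulator：调制因子 $(1-p_t)^\gamma$。代码用 $\exp\left(-\gamma y x - \gamma \log(1+e^{-x})\right)$ 计算，其中 $x$ 为 logit、$y$ 为标签。当 $y=1$ 时它等于 $\sigma(-x)^\gamma = (1-p)^\gamma$；当 $y=0$ 时等于 $\sigma(x)^\gamma = p^\gamma$。两种情况都正好是 $(1-p_t)^\gamma$；
3) weight：即传进来的参数 alpha。它是每个样本所属类别的类平衡权重 $\frac{1-\beta}{1-\beta^{n_y}}$，先归一化使各类权重之和等于类别数，再扩展为 [batch, num_classes] 的形状。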