Understanding and Implementing the Center Loss Function
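Center loss (Wen et al., ECCV 2016, linked in the code below) supplements the softmax loss by pulling each deep feature toward the center of its own class, making the learned features more discriminative. For a mini-batch of m samples with features x_i, labels y_i, and class centers c_j, the loss and the center update rule from the paper are

$$L_C = \frac{1}{2}\sum_{i=1}^{m} \lVert x_i - c_{y_i} \rVert_2^2, \qquad \Delta c_j = \frac{\sum_{i=1}^{m}\delta(y_i = j)\,(c_j - x_i)}{1 + \sum_{i=1}^{m}\delta(y_i = j)}$$

The implementation below takes two routes to the same result: a plain nn.Module (CenterLoss) that lets autograd differentiate the squared distance, and a custom torch.autograd.Function (CenterLossFunc) that hand-codes the backward pass, including the averaged center update Δc_j above. Loss1/LossFunc and Loss2 are small toy losses used to probe how autograd calls into a custom Function.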
#!/usr/bin/env python
# -*- coding: utf-8 -*-
# @Time    : 2019/12/1 22:03
# @Author  : dangxusheng
# @Email   : dangxusheng163@163.com
# @File    : center_loss.py

import torch
import torch.nn as nn
from torch.autograd import Function


class CenterLoss(nn.Module):
    """
    paper: http://ydwen.github.io/papers/WenECCV16.pdf
    code:  https://github.com/pangyupo/mxnet_center_loss
    pytorch code: https://blog.csdn.net/sinat_37787331/article/details/80296964
    """

    def __init__(self, features_dim, num_class=10, lamda=1., scale=1.0, batch_size=64):
        """
        :param features_dim: feature dimension = c*h*w
        :param num_class: number of classes
        :param lamda: weight of the center loss term, in [0, 1]
        :param scale: gradient scaling factor for the centers
        :param batch_size: batch size
        """
        super(CenterLoss, self).__init__()
        self.lamda = lamda
        self.num_class = num_class
        self.scale = scale
        self.batch_size = batch_size
        self.feat_dim = features_dim
        # store the center of each class, shape (num_class, features_dim)
        self.feature_centers = nn.Parameter(torch.randn([num_class, features_dim]))
        # self.lossfunc = CenterLossFunc.apply

    def forward(self, output_features, y_truth):
        """
        Compute the loss.
        :param output_features: features output by a conv layer, [b, c, h, w]
        :param y_truth: ground-truth labels, [b,]
        """
        batch_size = y_truth.size(0)
        output_features = output_features.view(batch_size, -1)
        assert output_features.size(-1) == self.feat_dim
        factor = self.scale / batch_size
        # return self.lamda * factor * self.lossfunc(output_features, y_truth, self.feature_centers)
        centers_batch = self.feature_centers.index_select(0, y_truth.long())  # [b, features_dim]
        diff = output_features - centers_batch
        loss = self.lamda * 0.5 * factor * diff.pow(2).sum()
        return loss


class CenterLossFunc(Function):
    # https://blog.csdn.net/xiewenbo/article/details/89286462
    @staticmethod
    def forward(ctx, feat, labels, centers):
        ctx.save_for_backward(feat, labels, centers)
        centers_batch = centers.index_select(0, labels.long())
        diff = feat - centers_batch
        return diff.pow(2).sum() / 2.0

    @staticmethod
    def backward(ctx, grad_output):
        # grad_output is the upstream gradient from the graph above, usually 1.0
        feature, label, centers = ctx.saved_tensors
        # group sample indices by class; needed for the per-class center update
        label_occur = dict()
        for i, label_v in enumerate(label.cpu().numpy()):
            label_occur.setdefault(int(label_v), []).append(i)

        delta_center = torch.zeros_like(centers)
        centers_batch = centers.index_select(0, label.long())
        diff = feature - centers_batch
        # accumulate the per-class sum of diffs
        grad_class_sum = torch.zeros([1, centers.size(-1)], device=centers.device)
        for label_v, sample_index in label_occur.items():
            grad_class_sum[:] = 0
            for i in sample_index:
                grad_class_sum += diff[i]
            # per-class mean gradient, following the center update rule of the paper
            delta_center[label_v] = -1 * grad_class_sum / (1 + len(sample_index))
        # the actual center update is carried out by the optimizer:
        # centers -= alpha * grad_output * delta_center
        # backward must return one gradient per forward input (feat, labels, centers)
        grad_center = grad_output * delta_center
        grad_feat = grad_output * diff
        grad_label = None
        return grad_feat, grad_label, grad_center


class Loss1(nn.Module):
    def __init__(self):
        super(Loss1, self).__init__()
        self.lossfunc = LossFunc.apply

    def forward(self, pred, truth):
        # return torch.abs(pred - truth)
        return self.lossfunc(pred, truth)


class LossFunc(Function):
    @staticmethod
    def forward(ctx, pred, truth):
        loss = torch.abs(pred - truth)
        ctx.save_for_backward(pred, truth)
        return loss

    @staticmethod
    def backward(ctx, grad_output):
        pred, truth = ctx.saved_tensors
        print(f'grad_output={grad_output}')
        return grad_output, None


class Loss2(nn.Module):
    def __init__(self):
        super(Loss2, self).__init__()

    def forward(self, pred, truth):
        return torch.abs(pred)
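Because the class centers live in an nn.Parameter, the autograd path updates them through whatever optimizer they are registered with; a larger learning rate for the centers than for the network is a common choice. Below is a minimal training sketch, not part of the original script: the Net model, the fake MNIST-sized batch, and the learning rates are all assumptions for illustration, and CenterLoss is the class defined above.

# Hypothetical training sketch: cross-entropy + center loss with a shared optimizer.
import torch
import torch.nn as nn

class Net(nn.Module):
    """Toy stand-in for a real feature extractor (assumption, not from the post)."""
    def __init__(self, feat_dim=2, num_class=10):
        super(Net, self).__init__()
        self.backbone = nn.Linear(28 * 28, feat_dim)      # produces the features
        self.classifier = nn.Linear(feat_dim, num_class)  # produces the logits

    def forward(self, x):
        feat = self.backbone(x)
        return feat, self.classifier(feat)

model = Net()
center_loss = CenterLoss(features_dim=2, num_class=10, lamda=0.1)
ce_loss = nn.CrossEntropyLoss()
# Registering center_loss.parameters() is what actually moves the centers;
# the 0.5 learning rate for them is an illustrative choice.
optimizer = torch.optim.SGD(
    [{'params': model.parameters()},
     {'params': center_loss.parameters(), 'lr': 0.5}], lr=0.01)

x = torch.randn(64, 28 * 28)        # fake batch
y = torch.randint(0, 10, (64,))     # fake labels
feat, logits = model(x)
loss = ce_loss(logits, y) + center_loss(feat, y)
optimizer.zero_grad()
loss.backward()
optimizer.step()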
if __name__ == '__main__':
    # test 1
    import random

    ct = CenterLoss(2, 10, 0.1, 1., batch_size=10)
    y = torch.Tensor([8., 3., 8., 5., 3., 0., 6., 5., 2., 3.])
    # y = torch.Tensor([random.choice(range(10)) for i in range(10)])
    feat = torch.zeros(10, 2).requires_grad_()
    out = ct(feat, y)
    print(f'forward loss = {out.item()}')
    out.backward()
    print(feat.grad)
    print(ct.feature_centers.grad)

    # # test 2
    # x = torch.Tensor([3.]).requires_grad_()
    # w = torch.nn.Parameter(torch.Tensor([2.]))
    # y = 2 * ((5 - w * x) ** 2)
    # ct = Loss1()
    # out = ct(y, torch.Tensor([10.]))
    # print(out.item())
    # out.backward()
    # print(x.grad)
    # print(w.grad)

    # # test 3
    # x = torch.Tensor([3.]).requires_grad_()
    # y = 2 * ((5 - x) ** 2)
    # ct = Loss2()
    # out = ct(y, 10)
    # print(out.item())
    # out.backward()
    # print(out.grad)
    # print(x.grad)
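As a quick sanity check (my addition, not in the original test), the analytic gradient of the module's loss with respect to the features is lamda * scale / batch * (feat - centers_batch), so the autograd result from test 1 can be verified directly by appending this to the block above:

    # Sanity check (added): compare autograd's feat.grad with the hand-derived
    # gradient lamda * scale / batch * (feat - centers_batch).
    expected = ct.lamda * ct.scale / y.size(0) * \
               (feat.detach() - ct.feature_centers.detach()[y.long()])
    print(torch.allclose(feat.grad, expected))  # expected: True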