Understanding and Implementing the Center Loss Function
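Center loss (Wen et al., ECCV 2016, linked in the class docstring below) pulls each sample's deep feature x_i towards a learnable center c_{y_i} of its class: L_C = 1/2 * sum_i ||x_i - c_{y_i}||^2. The centers are learned together with the network; the paper updates them with a per-class mean of the differences (a 1/(1 + n_j) factor), which is what the custom backward further down implements. A minimal sketch of the loss itself, with made-up shapes:

import torch

features = torch.randn(4, 2)          # [batch, feat_dim] deep features
labels = torch.tensor([0, 1, 1, 0])   # class index for each sample
centers = torch.randn(3, 2)           # [num_class, feat_dim] learnable class centers
diff = features - centers[labels]     # x_i - c_{y_i}
center_loss = 0.5 * diff.pow(2).sum() / features.size(0)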

#!/usr/bin/env python
# -*- coding: utf-8 -*-
# @Time    : 2019/12/1 22:03
# @Author  : dangxusheng
# @Email   : dangxusheng163@163.com
# @File    : center_loss.py

import torch
import torch.nn as nn
from torch.autograd import Function


class CenterLoss(nn.Module):
    """
    paper: http://ydwen.github.io/papers/WenECCV16.pdf
    code:  https://github.com/pangyupo/mxnet_center_loss
    pytorch code: https://blog.csdn.net/sinat_37787331/article/details/80296964
    """

    def __init__(self, features_dim, num_class=10, lamda=1., scale=1.0, batch_size=64):
        """
        初始化
        :param features_dim: 特征维度 = c*h*w
        :param num_class: 类别数量
        :param lamda   centerloss的权重系数 [0,1]
        :param scale:  center 的梯度缩放因子
        :param batch_size:  批次大小
        """
        super(CenterLoss, self).__init__()
        self.lamda = lamda
        self.num_class = num_class
        self.scale = scale
        self.batch_size = batch_size
        self.feat_dim = features_dim
        # stores the center of each class, shape: (num_class, features_dim)
        self.feature_centers = nn.Parameter(torch.randn([num_class, features_dim]))
        # self.lossfunc = CenterLossFunc.apply

    def forward(self, output_features, y_truth):
        """
        损失计算
        :param output_features: conv层输出的特征,  [b,c,h,w]
        :param y_truth:  标签值  [b,]
        :return:
        """
        batch_size = y_truth.size(0)
        output_features = output_features.view(batch_size, -1)
        assert output_features.size(-1) == self.feat_dim
        factor = self.scale / batch_size
        # return self.lamda * factor * self.lossfunc(output_features, y_truth, self.feature_centers)

        centers_batch = self.feature_centers.index_select(0, y_truth.long())  # [b,features_dim]
        diff = output_features - centers_batch
        loss = self.lamda * 0.5 * factor * (diff.pow(2).sum())
        return loss
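Because forward above uses ordinary autograd ops on the nn.Parameter feature_centers, the centers receive the exact analytic gradient of the loss (the per-class sum of c_{y_i} - x_i, times lamda * scale / batch_size), not the 1/(1 + n_j)-averaged update from the paper; the custom Function below reproduces the paper's rule instead. A tiny, hypothetical check:

ct = CenterLoss(features_dim=2, num_class=3, lamda=1.0, scale=1.0)
feat = torch.randn(4, 2, requires_grad=True)
y = torch.tensor([0, 1, 1, 0])
ct(feat, y).backward()
print(feat.grad)                  # (x_i - c_{y_i}) / batch_size per sample
print(ct.feature_centers.grad)    # -sum of (x_i - c_{y_i}) per class, / batch_size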


class CenterLossFunc(Function):
    # https://blog.csdn.net/xiewenbo/article/details/89286462
    @staticmethod
    def forward(ctx, feat, labels, centers):
        ctx.save_for_backward(feat, labels, centers)
        centers_batch = centers.index_select(0, labels.long())
        diff = feat - centers_batch
        return diff.pow(2).sum() / 2.0

    @staticmethod
    def backward(ctx, grad_output):
        # grad_output is the upstream gradient; for a scalar loss it is usually 1.0
        feature, label, centers = ctx.saved_tensors
        batch_size = label.size(0)
        # record the sample indices of each class, used below for the center gradients
        label_occur = dict()
        for i, label_v in enumerate(label.cpu().numpy()):
            label_occur.setdefault(int(label_v), []).append(i)

        # zeros_like keeps the device/dtype of `centers`, so this also works on CPU
        delta_center = torch.zeros_like(centers)
        centers_batch = centers.index_select(0, label.long())
        diff = feature - centers_batch

        # buffer for the per-class sum of diffs, shape [features_dim]
        grad_class_sum = centers.new_zeros(centers.size(-1))
        for label_v, sample_index in label_occur.items():
            grad_class_sum[:] = 0
            for i in sample_index:
                grad_class_sum += diff[i]
            # per-class average; the (1 + n_j) denominator follows the paper's center update rule
            delta_center[label_v] = -1 * grad_class_sum / (1 + len(sample_index))

        ## the actual center update (centers -= alpha * delta_center) is carried out by the optimizer
        # centers -= alpha * grad_output * delta_center

        # backward must return one gradient per forward input: (feat, labels, centers)
        grad_center = grad_output * delta_center
        grad_feat = grad_output * diff
        grad_label = None
        return grad_feat, grad_label, grad_center
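A quick, hypothetical check of the custom Function (shapes made up for illustration): feat.grad holds the per-sample diff, while centers.grad holds the paper-style averaged center update.

feat = torch.randn(4, 2, requires_grad=True)
labels = torch.tensor([0, 1, 1, 0])
centers = torch.randn(3, 2, requires_grad=True)
loss = CenterLossFunc.apply(feat, labels, centers)
loss.backward()
print(feat.grad.shape, centers.grad.shape)   # torch.Size([4, 2]) torch.Size([3, 2])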


# The toy modules below are minimal custom-Function examples used to inspect how
# grad_output flows through backward (exercised by test2 / test3 in __main__).
class Loss1(nn.Module):
    def __init__(self):
        super(Loss1, self).__init__()
        self.lossfunc = LossFunc.apply

    def forward(self, pred, truth):
        # return torch.abs(pred - truth)
        return self.lossfunc(pred, truth)


class LossFunc(Function):
    @staticmethod
    def forward(ctx, pred, truth):
        loss = torch.abs(pred - truth)
        ctx.save_for_backward(pred, truth)
        return loss

    @staticmethod
    def backward(ctx, grad_output):
        pred, truth = ctx.saved_tensors
        print(f'grad_output={grad_output}')
        # d|pred - truth| / d pred = sign(pred - truth); `truth` needs no gradient
        return grad_output * torch.sign(pred - truth), None


class Loss2(nn.Module):
    # plain autograd version (no custom Function); note it returns |pred| and ignores `truth`
    def __init__(self):
        super(Loss2, self).__init__()

    def forward(self, pred, truth):
        return torch.abs(pred)


if __name__ == '__main__':
    # test 1
    import random

    ct = CenterLoss(2, 10, 0.1, 1., batch_size=10)
    y = torch.Tensor([8., 3., 8., 5., 3., 0., 6., 5., 2., 3.])
    # y = torch.Tensor([random.choice(range(10)) for i in range(10)])
    feat = torch.zeros(10, 2).requires_grad_()
    out = ct(feat, y)
    print(f'forward loss = {out.item()}')
    out.backward()
    print(feat.grad)
    print(ct.feature_centers.grad)

    # # test2
    # x = torch.Tensor([3.]).requires_grad_()
    # w = torch.nn.Parameter(torch.Tensor([2.]))
    # y = 2 * ((5 - w * x) ** 2)
    # ct = Loss1()
    # out = ct(y, torch.Tensor([10.]))
    # print(out.item())
    # out.backward()
    # print(x.grad)
    # print(w.grad)

    # # test3
    # x = torch.Tensor([3.]).requires_grad_()
    # y = 2 * ((5 - x) ** 2)
    # ct = Loss2()
    # out = ct(y, 10)
    # print(out.item())
    # out.backward()
    # print(out.grad)
    # print(x.grad)
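For completeness, a hedged sketch of how the module is usually combined with a softmax classifier during training: the cross-entropy loss and the center loss are summed, and the centers (an nn.Parameter inside CenterLoss) are typically given their own, larger learning rate (the alpha of the paper). The model, data loader and hyper-parameters below are placeholders, not part of the original code.

# hypothetical training loop: softmax loss + center loss
# assumes `model` returns (features, logits) and `loader` yields (images, labels)
center_loss = CenterLoss(features_dim=2, num_class=10, lamda=0.1, scale=1.0)
ce_loss = nn.CrossEntropyLoss()
optimizer = torch.optim.SGD([
    {'params': model.parameters(), 'lr': 1e-2},
    {'params': center_loss.parameters(), 'lr': 0.5},   # centers often get a larger lr (alpha)
], momentum=0.9)

for images, labels in loader:
    features, logits = model(images)
    loss = ce_loss(logits, labels) + center_loss(features, labels)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()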

 
