A PyTorch implementation of a label-smoothing classification loss (LSR)
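Label smoothing replaces the hard one-hot target with a softened distribution: the true class keeps probability 1 - eps, and the remaining eps is spread uniformly over the other K - 1 classes:

    q_i = 1 - eps          if i == y
    q_i = eps / (K - 1)    otherwise

The loss is the cross entropy between this soft target q and the log-softmax of the model outputs, which discourages over-confident predictions.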

import torch
import torch.nn as nn
import torch.nn.functional as F


class LSR(nn.Module):
    """Cross-entropy loss with label smoothing (Label Smoothing Regularization)."""

    def __init__(self, n_classes=10, eps=0.1):
        super(LSR, self).__init__()
        self.n_classes = n_classes  # number of classes K
        self.eps = eps              # smoothing factor

    def forward(self, outputs, labels):
        # outputs: [b, n_classes] raw logits; labels: [b,] class indices
        assert outputs.size(0) == labels.size(0)
        n_classes = self.n_classes
        one_hot = F.one_hot(labels, n_classes).float()
        # Non-target positions get eps / (n_classes - 1); the target keeps 1 - eps.
        mask = ~(one_hot > 0)
        smooth_labels = one_hot.masked_fill(mask, self.eps / (n_classes - 1))
        smooth_labels = smooth_labels.masked_fill(~mask, 1 - self.eps)
        # Cross entropy between the smoothed targets and the log-softmax outputs.
        ce_loss = torch.sum(-smooth_labels * F.log_softmax(outputs, dim=1), dim=1).mean()
        # For comparison, the standard (unsmoothed) cross entropy:
        # ce_loss = F.nll_loss(F.log_softmax(outputs, dim=1), labels, reduction='mean')
        return ce_loss

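For example, with n_classes = 4 and eps = 0.1, a label y = 2 is smoothed into the target vector [0.0333, 0.0333, 0.9, 0.0333]: the true class keeps 0.9 and each other class receives 0.1 / 3.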

if __name__ == '__main__':
    labels = torch.tensor([0, 1, 2, 1, 1, 0, 3])
    eps = 0.1
    n_classes = 4
    outputs = torch.rand([7, 4])  # random logits for a batch of 7 samples, 4 classes
    print(outputs)

    LL = LSR(n_classes, eps)
    LL2 = nn.CrossEntropyLoss()
    loss = LL(outputs, labels)    # smoothed cross entropy
    loss2 = LL2(outputs, labels)  # standard cross entropy, for comparison
    print(loss)
    print(loss2)
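
Note: since PyTorch 1.10, label smoothing is built into nn.CrossEntropyLoss, so the module above can be replaced with the one-liner below (a minimal sketch, assuming PyTorch >= 1.10):

    # Built-in equivalent (PyTorch >= 1.10)
    loss_fn = nn.CrossEntropyLoss(label_smoothing=0.1)
    loss = loss_fn(outputs, labels)

One small difference: the built-in version spreads eps uniformly over all K classes (the target class gets 1 - eps + eps / K), while the LSR module above spreads it only over the K - 1 non-target classes, so the two losses differ slightly.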

 
