A PyTorch implementation of the label-smoothing classification loss
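Label smoothing replaces the hard one-hot target with a softened distribution: the true class keeps a weight of 1 - eps, and the remaining eps is split evenly across the other n_classes - 1 classes. For example, with eps = 0.1 and 4 classes, a sample labeled 2 is trained against the soft target [0.0333, 0.0333, 0.9, 0.0333] instead of [0, 0, 1, 0]; the loss is then the cross entropy between this soft target and the model's log-softmax output.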

import torch
import torch.nn as nn
import torch.nn.functional as F


class LSR(nn.Module):
    def __init__(self, n_classes=10, eps=0.1):
        # n_classes: number of classes; eps: smoothing factor taken from the true class
        super(LSR, self).__init__()
        self.n_classes = n_classes
        self.eps = eps

    def forward(self, outputs, labels):
        # outputs: [b, n_classes] raw logits; labels: [b] integer class indices
        assert outputs.size(0) == labels.size(0)
        n_classes = self.n_classes
        one_hot = F.one_hot(labels, n_classes).float()
        mask = ~(one_hot > 0)  # True at the non-target positions
        # Build the smoothed targets: the true class keeps 1 - eps, and the
        # remaining eps is split evenly over the other n_classes - 1 classes.
        smooth_labels = one_hot.masked_fill(mask, self.eps / (n_classes - 1))
        smooth_labels = smooth_labels.masked_fill(~mask, 1 - self.eps)
        # Cross entropy between the soft targets and the log-softmax outputs.
        ce_loss = torch.sum(-smooth_labels * F.log_softmax(outputs, dim=1), dim=1).mean()
        # Standard (non-smoothed) cross entropy, kept for reference:
        # ce_loss = F.nll_loss(F.log_softmax(outputs, dim=1), labels, reduction='mean')
        return ce_loss


if __name__ == '__main__':
    labels = [0, 1, 2, 1, 1, 0, 3]
    labels = torch.tensor(labels)
    eps = 0.1
    n_classes = 4
    outputs = torch.rand([7, 4])  # random logits: batch of 7 samples, 4 classes
    print(outputs)

    LL = LSR(n_classes, eps)
    LL2 = nn.CrossEntropyLoss()
    loss = LL(outputs, labels)    # label-smoothed cross entropy
    loss2 = LL2(outputs, labels)  # standard cross entropy, for comparison
    print(loss)
    print(loss2)
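
Since PyTorch 1.10, nn.CrossEntropyLoss also accepts a label_smoothing argument, which makes a handy sanity check. A minimal sketch, continuing the demo above (reusing outputs, labels, eps, n_classes and LSR); note that the built-in version spreads eps over all classes, so the target class keeps 1 - eps + eps/n_classes (0.925 here) rather than 1 - eps (0.9) as in LSR, and the two loss values will be close but not identical.

# Built-in label smoothing (PyTorch >= 1.10): eps is spread over all classes,
# so the target class keeps 1 - eps + eps/n_classes instead of 1 - eps.
builtin = nn.CrossEntropyLoss(label_smoothing=eps)
print(builtin(outputs, labels))              # close to the LSR value, but not identical
print(LSR(n_classes, eps)(outputs, labels))  # the variant implemented above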