Understanding hinge loss / support vector machine (SVM) loss

https://blog.csdn.net/AI_focus/article/details/78339234

https://www.cnblogs.com/massquantity/p/8964029.html
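
As a compact reference alongside the linked posts, the multi-class hinge loss documented for `torch.nn.MultiMarginLoss` can be written per sample as follows. The symbol names are chosen here for illustration: x is the score vector for one sample, y its ground-truth class index, C the number of classes, and p is 1 or 2.

```latex
\ell(x, y) = \frac{1}{C} \sum_{i \neq y} \max\bigl(0,\ \text{margin} - x_y + x_i\bigr)^{p}
```

With p = 1 this is the plain hinge loss and with p = 2 the squared hinge loss. With `reduction='mean'` the per-sample values are then averaged over the batch.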

A PyTorch implementation of HingeLoss:

"""
    Hinge loss
    SVM hinge loss, equivalent to torch.nn.MultiMarginLoss.
    Per sample, summing over every class whose index is not the ground-truth label:
    hinge loss = sum(max(0, margin - true + pred)) / n_classes
    squared hinge loss = sum(max(0, margin - true + pred)^2) / n_classes
    The per-sample values are then averaged over the batch.
    The class below implements the squared variant (p=2).
    Reference blog: https://blog.csdn.net/AI_focus/article/details/78339234
    Reference (PyTorch): torch.nn.MultiMarginLoss and the SVM hinge loss
    https://pytorch.org/docs/1.6.0/generated/torch.nn.MultiMarginLoss.html?highlight=marginloss#torch.nn.MultiMarginLoss
"""

import torch
import torch.nn as nn
import torch.nn.functional as F


class HingeLoss(nn.Module):
    def __init__(self, n_classes=10, margin=1.):
        super(HingeLoss, self).__init__()
        self.n_classes = n_classes
        self.margin = margin

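    def forward(self, outputs, labels):
        # outputs.shape: [b, n_classes]; labels.shape: [b,]
        assert outputs.size(0) == labels.size(0)
        one_hot = F.one_hot(labels, self.n_classes).float()
        idx = torch.arange(labels.size(0), device=labels.device)
        # Model output at the ground-truth label position, shape [b, 1]
        target_loc_outputs = outputs[idx, labels].unsqueeze(1)
        mask = (one_hot == 0)
        # Model outputs at the non-label positions, shape [b, n_classes - 1]
        other_loc_outputs = outputs[mask].reshape([labels.size(0), -1])
        # Squared hinge (p=2): clamp margin violations at 0, square them, sum over
        # classes, and divide by n_classes as torch.nn.MultiMarginLoss does
        loss = torch.max(torch.zeros_like(other_loc_outputs),
                         self.margin - target_loc_outputs + other_loc_outputs).pow(2) \
                   .sum(dim=1) / self.n_classes
        return loss.mean()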


if __name__ == '__main__':
    n_classes = 4
    labels = torch.tensor([0, 1, 2, 1, 1, 0, 3])
    outputs = torch.randn([7, n_classes])

    # Custom squared hinge loss with margin 0.5
    ll = HingeLoss(n_classes, 0.5)
    loss = ll(outputs, labels)
    print(loss)

    # Built-in reference: with p=2 and the same margin it should print the same value
    l2 = nn.MultiMarginLoss(p=2, margin=0.5, reduction='mean')
    loss2 = l2(outputs, labels)
    print(loss2)
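
The comparison against `nn.MultiMarginLoss` above uses random scores, so the printed values are hard to verify by eye. The following is a minimal sketch with scores small enough to check by hand. The helper name `multi_margin_hinge` and its `p` argument are illustrative assumptions, not part of the original class; it uses `gather`/`scatter` instead of the one-hot mask, but is intended to compute the same quantity.

```python
import torch
import torch.nn as nn


def multi_margin_hinge(outputs, labels, margin=1.0, p=1):
    """Hypothetical helper: multi-class hinge loss, averaged over classes and batch."""
    # Score of the ground-truth class for each sample, shape [b, 1]
    target = outputs.gather(1, labels.unsqueeze(1))
    # Margin violations for every class, clamped at 0 and raised to p, shape [b, C]
    violations = (margin - target + outputs).clamp(min=0).pow(p)
    # The ground-truth class must not contribute to the sum, so zero out its column
    violations = violations.scatter(1, labels.unsqueeze(1), 0.0)
    # Divide by the number of classes, then average over the batch
    return violations.sum(dim=1).div(outputs.size(1)).mean()


# Hand-checkable case: one sample, three classes, ground truth is class 0
outputs = torch.tensor([[1.0, 2.0, 0.0]])
labels = torch.tensor([0])

# Class 1: max(0, 1 - 1 + 2) = 2; class 2: max(0, 1 - 1 + 0) = 0
# p=1 gives (2 + 0) / 3, p=2 gives (4 + 0) / 3
loss_p1 = multi_margin_hinge(outputs, labels, margin=1.0, p=1)
loss_p2 = multi_margin_hinge(outputs, labels, margin=1.0, p=2)

# Built-in losses with the same settings, for comparison
ref_p1 = nn.MultiMarginLoss(p=1, margin=1.0)(outputs, labels)
ref_p2 = nn.MultiMarginLoss(p=2, margin=1.0)(outputs, labels)
print(loss_p1.item(), ref_p1.item())
print(loss_p2.item(), ref_p2.item())
```

Exposing `p` makes it easy to switch between the plain and squared hinge loss, mirroring the `p` argument of `nn.MultiMarginLoss`.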

 

  

Posted by dangxusheng