Understanding hinge loss / support vector loss

https://blog.csdn.net/AI_focus/article/details/78339234

https://www.cnblogs.com/massquantity/p/8964029.html
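For reference, the multi-class hinge loss implemented below can be written as follows. This follows the form given in the torch.nn.MultiMarginLoss documentation, leaving out its optional per-class weights. C is the number of classes, x is one sample's score vector, y is its target class, and N is the batch size. p = 1 gives the hinge loss and p = 2 gives the squared hinge loss:

```latex
\ell(x, y) = \frac{1}{C} \sum_{j \neq y} \max\left(0,\ \text{margin} - x_y + x_j\right)^{p},
\qquad
L = \frac{1}{N} \sum_{n=1}^{N} \ell\left(x^{(n)}, y^{(n)}\right)
```

The sum runs only over the wrong classes j. A wrong class contributes nothing once the target score x_y beats x_j by at least the margin.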

PyTorch implementation of HingeLoss:

"""
    Hinge loss
    Multi-class SVM hinge loss, equivalent to torch.nn.MultiMarginLoss (without class weights).
    For each sample with target class y and scores x over C classes:
    hinge loss         = sum_{j != y} max(0, margin - x[y] + x[j]) / C
    squared hinge loss = sum_{j != y} max(0, margin - x[y] + x[j])^2 / C
    The per-sample losses are then averaged over the batch (reduction='mean').
    Reference blog: https://blog.csdn.net/AI_focus/article/details/78339234
    Reference: torch.nn.MultiMarginLoss and the SVM hinge loss
    https://pytorch.org/docs/1.6.0/generated/torch.nn.MultiMarginLoss.html?highlight=marginloss#torch.nn.MultiMarginLoss
"""

import torch
import torch.nn as nn
import torch.nn.functional as F


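class HingeLoss(nn.Module):
    def __init__(self, n_classes=10, margin=1., p=2):
        super().__init__()
        # p=1: hinge loss, p=2: squared hinge loss (the only values MultiMarginLoss supports)
        assert p in (1, 2), "p must be 1 or 2"
        self.n_classes = n_classes
        self.margin = margin
        self.p = p

    def forward(self, outputs, labels):
        # outputs.shape: [b, n_classes], labels.shape: [b,] (int64 class indices)
        assert outputs.size(0) == labels.size(0)
        one_hot = F.one_hot(labels, self.n_classes).float()
        # model output at the target-class position, shape [b, 1]
        target_loc_outputs = outputs.gather(1, labels.unsqueeze(1))
        mask = (one_hot == 0)
        # model outputs at the non-target positions, shape [b, n_classes - 1]
        other_loc_outputs = outputs[mask].reshape([labels.size(0), -1])
        # margin violations of each wrong class, clamped at zero
        margins = torch.clamp(self.margin - target_loc_outputs + other_loc_outputs, min=0)
        # sum over the wrong classes and divide by n_classes, as MultiMarginLoss does
        loss = margins.pow(self.p).sum(dim=1) / self.n_classes
        # average over the batch (reduction='mean')
        return loss.mean()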


if __name__ == '__main__':
    torch.manual_seed(0)  # fixed seed so runs are reproducible
    n_classes = 4
    labels = torch.tensor([0, 1, 2, 1, 1, 0, 3])
    outputs = torch.randn([7, n_classes])

    # custom squared hinge loss with margin 0.5
    ll = HingeLoss(n_classes, margin=0.5)
    loss = ll(outputs, labels)
    print(loss)

    # built-in equivalent; the two printed values should match
    l2 = nn.MultiMarginLoss(p=2, margin=0.5, reduction='mean')
    loss2 = l2(outputs, labels)
    print(loss2)
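To check the numbers by hand without PyTorch, here is a minimal framework-free sketch of the same formula. It assumes plain Python lists of scores and integer labels. It is meant to mirror torch.nn.MultiMarginLoss with reduction='mean' and no class weights. The function name multi_margin_loss is hypothetical and not part of any library.

```python
def multi_margin_loss(outputs, labels, margin=1.0, p=1):
    """Hypothetical reference implementation of the multi-class hinge loss.

    outputs: list of per-sample score lists (each of length C)
    labels:  list of integer target classes
    p=1 gives the hinge loss, p=2 the squared hinge loss.
    """
    if p not in (1, 2):
        raise ValueError("p must be 1 or 2")
    total = 0.0
    for scores, y in zip(outputs, labels):
        n_classes = len(scores)
        sample = 0.0
        for j, s in enumerate(scores):
            if j == y:
                continue  # the target class itself is skipped
            sample += max(0.0, margin - scores[y] + s) ** p
        total += sample / n_classes  # divide by the number of classes, not the batch size
    return total / len(outputs)  # mean over the batch


if __name__ == "__main__":
    outputs = [[2.0, 1.5, -1.0], [0.0, 0.5, 1.0]]
    labels = [0, 2]
    print(multi_margin_loss(outputs, labels))       # hinge loss, p=1
    print(multi_margin_loss(outputs, labels, p=2))  # squared hinge loss, p=2
```

On this input each sample has exactly one violated margin of 0.5. So the hinge loss is (0.5/3 + 0.5/3) / 2 = 1/6, and the squared hinge loss is (0.25/3 + 0.25/3) / 2 = 1/12.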

posted @ 2019-12-19 23:27  dangxusheng