Pytorch-区分nn.BCELoss()、nn.BCEWithLogitsLoss()和nn.CrossEntropyLoss() 的用法

详细理论部分可参考https://www.cnblogs.com/wanghui-garcia/p/10862733.html

BCELoss()和BCEWithLogitsLoss()的输出logits和目标labels(必须是one_hot形式)的形状相同

CrossEntropyLoss()的目标labels的形状是[3, 1](以下面为例,不能是one_hot形式),输出logits是[3, 2]。如果是多分类,labels的形状是[batch, 1],值为0~num_classes-1之间。

复制代码
 1 import torch
 2 import torch.nn as nn
 3 import torch.nn.functional as F
 4 
 5 m = nn.Sigmoid()
 6 
 7 loss_f1 = nn.BCELoss()
 8 loss_f2 = nn.BCEWithLogitsLoss()
 9 loss_f3 = nn.CrossEntropyLoss()
10 
11 logits = torch.randn(3, 2)
12 labels = torch.FloatTensor([[0, 1], [1, 0], [1, 0]])
13 
14 print(loss_f1(m(logits), labels))      #tensor(0.9314),注意logits先被激活函数作用
15 print(loss_f2(logits, labels))         #tensor(0.9314)
16 
17 label2 = torch.LongTensor([1, 0, 0])
18 print(loss_f3(logits, label2))         #tensor(1.2842)
19 
20 logits3 = torch.randn(3, 10)       #如果是十分类 
21 label3 = torch.LongTensor([9, 2, 5])
22 print(loss_f3(logits3, label3))        #tensor(2.6467)
复制代码

 如果label2也想变成labels,然后通过BCELoss进行计算的话,可以先转变成独热编码的形式:

1 encode = F.one_hot(label2, num_classes = 2)               #encode的值和labels一样,但是类型是LongTensor
2 print(loss_f1(m(logits), encode.type(torch.float32)))     #tensor(0.9314)

Tip:BCEWithLogitsLoss()可以用于多标签分类,将最后分类层的每个输出节点使用sigmoid激活函数激活,然后对每个输出节点和对应的标签计算交叉熵损失函数。

复制代码
 1 import torch
 2 import numpy as np
 3 
 4 pred = np.array([[-0.4089, -1.2471, 0.5907],
 5                 [-0.4897, -0.8267, -0.7349],
 6                 [0.5241, -0.1246, -0.4751]])
 7 label = np.array([[0, 1, 1],
 8                   [0, 0, 1],
 9                   [1, 0, 1]])
10 
11 pred = torch.from_numpy(pred).float()
12 label = torch.from_numpy(label).float()
13 
14 crition1 = torch.nn.BCEWithLogitsLoss()
15 loss1 = crition1(pred, label)
16 print(loss1)              #tensor(0.7193)
17 
18 crition2 = torch.nn.MultiLabelSoftMarginLoss()
19 loss2 = crition2(pred, label)
20 print(loss2)              #tensor(0.7193)
复制代码

 

posted @   最咸的鱼  阅读(6243)  评论(0编辑  收藏  举报
编辑推荐:
· .NET Core 中如何实现缓存的预热?
· 从 HTTP 原因短语缺失研究 HTTP/2 和 HTTP/3 的设计差异
· AI与.NET技术实操系列:向量存储与相似性搜索在 .NET 中的实现
· 基于Microsoft.Extensions.AI核心库实现RAG应用
· Linux系列:如何用heaptrack跟踪.NET程序的非托管内存泄露
阅读排行:
· TypeScript + Deepseek 打造卜卦网站:技术与玄学的结合
· 阿里巴巴 QwQ-32B真的超越了 DeepSeek R-1吗?
· 【译】Visual Studio 中新的强大生产力特性
· 10年+ .NET Coder 心语 ── 封装的思维:从隐藏、稳定开始理解其本质意义
· 【设计模式】告别冗长if-else语句:使用策略模式优化代码结构
点击右上角即可分享
微信分享提示