Pytorch-区分nn.BCELoss()、nn.BCEWithLogitsLoss()和nn.CrossEntropyLoss() 的用法
详细理论部分可参考https://www.cnblogs.com/wanghui-garcia/p/10862733.html
BCELoss()和BCEWithLogitsLoss()的输出logits和目标labels(必须是one_hot形式)的形状相同。
CrossEntropyLoss()的目标labels的形状是[3, 1](以下面为例,不能是one_hot形式),输出logits是[3, 2]。如果是多分类,labels的形状是[batch, 1],值为0~num_classes-1之间。
1 import torch 2 import torch.nn as nn 3 import torch.nn.functional as F 4 5 m = nn.Sigmoid() 6 7 loss_f1 = nn.BCELoss() 8 loss_f2 = nn.BCEWithLogitsLoss() 9 loss_f3 = nn.CrossEntropyLoss() 10 11 logits = torch.randn(3, 2) 12 labels = torch.FloatTensor([[0, 1], [1, 0], [1, 0]]) 13 14 print(loss_f1(m(logits), labels)) #tensor(0.9314),注意logits先被激活函数作用 15 print(loss_f2(logits, labels)) #tensor(0.9314) 16 17 label2 = torch.LongTensor([1, 0, 0]) 18 print(loss_f3(logits, label2)) #tensor(1.2842) 19 20 logits3 = torch.randn(3, 10) #如果是十分类 21 label3 = torch.LongTensor([9, 2, 5]) 22 print(loss_f3(logits3, label3)) #tensor(2.6467)
如果label2也想变成labels,然后通过BCELoss进行计算的话,可以先转变成独热编码的形式:
1 encode = F.one_hot(label2, num_classes = 2) #encode的值和labels一样,但是类型是LongTensor 2 print(loss_f1(m(logits), encode.type(torch.float32))) #tensor(0.9314)
Tip:BCEWithLogitsLoss()可以用于多标签分类,将最后分类层的每个输出节点使用sigmoid激活函数激活,然后对每个输出节点和对应的标签计算交叉熵损失函数。
1 import torch 2 import numpy as np 3 4 pred = np.array([[-0.4089, -1.2471, 0.5907], 5 [-0.4897, -0.8267, -0.7349], 6 [0.5241, -0.1246, -0.4751]]) 7 label = np.array([[0, 1, 1], 8 [0, 0, 1], 9 [1, 0, 1]]) 10 11 pred = torch.from_numpy(pred).float() 12 label = torch.from_numpy(label).float() 13 14 crition1 = torch.nn.BCEWithLogitsLoss() 15 loss1 = crition1(pred, label) 16 print(loss1) #tensor(0.7193) 17 18 crition2 = torch.nn.MultiLabelSoftMarginLoss() 19 loss2 = crition2(pred, label) 20 print(loss2) #tensor(0.7193)
【推荐】国内首个AI IDE,深度理解中文开发场景,立即下载体验Trae
【推荐】编程新体验,更懂你的AI,立即体验豆包MarsCode编程助手
【推荐】抖音旗下AI助手豆包,你的智能百科全书,全免费不限次数
【推荐】轻量又高性能的 SSH 工具 IShell:AI 加持,快人一步
· .NET Core 中如何实现缓存的预热?
· 从 HTTP 原因短语缺失研究 HTTP/2 和 HTTP/3 的设计差异
· AI与.NET技术实操系列:向量存储与相似性搜索在 .NET 中的实现
· 基于Microsoft.Extensions.AI核心库实现RAG应用
· Linux系列:如何用heaptrack跟踪.NET程序的非托管内存泄露
· TypeScript + Deepseek 打造卜卦网站:技术与玄学的结合
· 阿里巴巴 QwQ-32B真的超越了 DeepSeek R-1吗?
· 【译】Visual Studio 中新的强大生产力特性
· 10年+ .NET Coder 心语 ── 封装的思维:从隐藏、稳定开始理解其本质意义
· 【设计模式】告别冗长if-else语句:使用策略模式优化代码结构