负对数似然(NLL)和困惑度(PPL)
让我们通过一个简单的例子来演示这段代码的计算过程,包括负对数似然(NLL)和困惑度(PPL)的计算。为了简化,我们将假设一个非常小的模型输出和数据。
假设:
- 我们有两个样本(即 batch size 为 2)。
- 每个样本有 3 个可能的类别,
S_logits
是模型输出的 logits。 smask
是一个掩码,假设全部为True
,即我们对所有样本和所有类别都进行处理。s_batch_id
是一个表示每个样本的索引的向量,用于scatter_mean
的计算。
1. 模型输出的 logits:
假设 r_pred_S_logits
的最后一层输出如下(为了简单,假设只有一个时间步长):
import torch
# 假设的 logits
r_pred_S_logits = [torch.tensor([[[2.0, 1.0, 0.1], [1.0, 2.5, 0.5]]])]
# 掩码
smask = torch.tensor([True, True])
# 批次 ID(假设第一个样本和第二个样本)
s_batch_id = torch.tensor([0, 1])
2. 计算 softmax 概率分布:
首先,对 S_logits
进行 softmax 操作:
S_logits = r_pred_S_logits[-1][0][smask] # shape: (2, 3)
S_dists = torch.softmax(S_logits, dim=-1) # shape: (2, 3)
print(S_dists)
这将输出:
tensor([[0.6590, 0.2424, 0.0986],
[0.2312, 0.6285, 0.1403]])
每一行是一个样本的概率分布。
3. 采样类别:
然后,从 S_dists
中使用 torch.multinomial
采样类别:
pred_S = torch.zeros_like(smask, dtype=torch.long)
pred_S[smask] = torch.multinomial(S_dists, num_samples=1).squeeze()
print(pred_S)
假设采样结果为:
tensor([0, 1])
这意味着第一个样本预测为类别 0,第二个样本预测为类别 1。
4. 计算 NLL:
我们从 S_dists
中提取出预测类别的概率,并计算负对数似然(NLL):
S_probs = S_dists[torch.arange(s_batch_id.shape[0]), pred_S[smask]]
print(S_probs)
假设输出为:
tensor([0.6590, 0.6285])
计算 NLL:
nlls = -torch.log(S_probs)
print(nlls)
输出:
tensor([0.4170, 0.4642])
这两个值是每个样本的 NLL 值。
5. 计算 PPL:
最后,通过 scatter_mean
计算每个 batch 的平均 NLL(在这个简单的例子中,每个样本有一个唯一的 ID,所以直接取平均值),然后困惑度(PPL)可以通过取指数得到:
from torch_scatter import scatter_mean
ppl = scatter_mean(nlls, s_batch_id)
print(ppl)
结果:
tensor([0.4170, 0.4642])
困惑度可以通过取指数计算得到:
ppl_value = torch.exp(ppl)
print(ppl_value)
结果:
tensor([1.5172, 1.5901])
总结
在这个简单的例子中:
- NLL: 是
0.4170
和0.4642
,分别对应两个样本的负对数似然。 - 困惑度(PPL): 通过
exp(NLL)
计算得出1.5172
和1.5901
。
这些值反映了模型预测的质量,困惑度越小,表示模型对样本的预测越好。
负对数似然是怎么计算的
负对数似然(Negative Log-Likelihood, NLL)是用来衡量模型预测的概率与真实类别之间的差异的一种损失函数。在分类任务中,它可以用于评估模型预测的准确性。NLL 的计算是基于模型输出的概率(通常通过 softmax 函数生成的概率分布)和真实类别标签的。
负对数似然的计算步骤
-
模型输出 logits:
- 模型通常输出 logits,也就是未归一化的分数。这些分数代表了模型对每个类别的信心。
-
Softmax 函数:
- 将 logits 转换为概率分布。这一步骤使用 softmax 函数,使得每个类别的分数被归一化为一个概率值,所有类别的概率之和为 1。
公式如下:
[
P(y_i) = \frac{\exp(\text{logit}i)}{\sum^{N} \exp(\text{logit}_j)}
]其中,( P(y_i) ) 是类别 ( y_i ) 的概率,logit 是模型输出的原始分数。
-
选择预测类别的概率:
- 从 softmax 生成的概率分布中,选择实际发生的类别(或模型预测的类别)的概率。
-
计算负对数似然:
- 取所选概率的负对数作为负对数似然值(NLL)。
公式如下:
[
\text{NLL} = -\log(P(y_{\text{true}}))
]其中,( P(y_{\text{true}}) ) 是真实类别的预测概率。
例子
假设我们有一个三分类问题,模型输出的 logits 为:
import torch
logits = torch.tensor([2.0, 1.0, 0.1])
1. Softmax 计算概率分布:
probs = torch.softmax(logits, dim=-1)
print(probs)
这将输出:
tensor([0.6590, 0.2424, 0.0986])
即,类别 0 的概率是 0.6590,类别 1 的概率是 0.2424,类别 2 的概率是 0.0986。
2. 假设真实类别是 0,那么选择类别 0 的概率:
P_true = probs[0]
print(P_true)
输出:
tensor(0.6590)
3. 计算负对数似然:
nll = -torch.log(P_true)
print(nll)
输出:
tensor(0.4170)
这个值 ( 0.4170 ) 就是类别 0 的负对数似然,它反映了模型对这个类别的预测质量。
总结
- 负对数似然(NLL) 是模型对某个类别预测概率的负对数。
- NLL 越小,说明模型对真实类别的预测概率越高,模型的表现越好。
- NLL 越大,说明模型对真实类别的预测概率越低,模型的表现越差。
【推荐】国内首个AI IDE,深度理解中文开发场景,立即下载体验Trae
【推荐】编程新体验,更懂你的AI,立即体验豆包MarsCode编程助手
【推荐】抖音旗下AI助手豆包,你的智能百科全书,全免费不限次数
【推荐】轻量又高性能的 SSH 工具 IShell:AI 加持,快人一步
· 震惊!C++程序真的从main开始吗?99%的程序员都答错了
· 单元测试从入门到精通
· 【硬核科普】Trae如何「偷看」你的代码?零基础破解AI编程运行原理
· 上周热点回顾(3.3-3.9)
· winform 绘制太阳,地球,月球 运作规律