Pytorch中的交叉熵CE和均方误差MSE分别是如何计算的?

本文主要关注输入输出的形状,通过两种标签形式探索一下其计算机制。

交叉熵损失函数

实验设置:假设采用AutoEncoder+分类器。AE负责重构图像,计算MSE。分类器通过latent vector计算23个类别的概率向量。

    import torch nn as nn
    net = AutoEncoder(num_classes=23)
    import torch.nn.functional as F
    loss_fn = nn.CrossEntropyLoss()
    x_spec = torch.rand(64, 128, 313)  # 假设输入四维张量,单通道
    x_label = torch.randint(0, 23, size=(64,))  # 23个类别
    pred, feature = net(x_wav, x_spec, x_label)
    print(pred.shape)  # (64, 23)
    print(x_label.shape) # (64,)
    print(feature.shape)  # (64, 1, 128, 313)

1. 分别用 torch.nn.CrossEntropy()和F.crossentropy(),通过形状为(64,)的实数标签计算

2. 分别用torch.nn.CrossEntropy()和F.crossentropy(),把(64,)的实数标签转换为形状为(64, 23)的One-Hot标签计算

    print(loss_fn(pred, x_label))
    print(F.cross_entropy(pred, x_label, reduction="none").shape)

    one_hot = torch.zeros(pred.shape, device=pred.device)
    one_hot = one_hot.scatter_(1, x_label.unsqueeze(1).long(), 1)
    print(loss_fn(pred, one_hot))
    print(F.cross_entropy(pred, one_hot, reduction="none").shape)

输出如下,最重要的是,如果要保留不同样本的loss,就不应该用nn.CrossEntropy()而是F.crossentropy(, reduction="none"):

tensor(3.1840, grad_fn=<NllLossBackward0>)
torch.Size([64])
tensor(3.1840, grad_fn=<DivBackward1>)
torch.Size([64])

没有区别,实际两者殊途同归,因为CE的计算也有两种方式:

1. 取出预测概率向量的指定维度数值,和groundtruth(真实标签)对比

2. 把真实标签转换为one-hot向量,然后和预测概率向量计算交叉熵,由于除了自身类别之外的维度都是0,因此没有区别。

 

MSE损失函数

这个就大不相同了,nn.MSELoss()对两个样本计算后返回一个标量数值,

但是通过F.mse_loss(input_x, recon_loss, reduction="none")返回的是一个和样本同样形状的张量,需要再通过loss = loss.mean(axis=3).mean(axis=2).mean(axis=1)来规约为(64,)形状的向量,否则所有样本的损失之都会被规约。不过,假如没有reduction="none"参数的话,返回值则是一个标量,已经被规约完了。

    loss_fn = nn.MSELoss()
    print(loss_fn(x_spec, feature))
    print(F.mse_loss(x_spec, feature))
    print(F.mse_loss(x_spec, feature, reduction="none").shape)
    print(F.mse_loss(x_spec, feature, reduction="none").mean(axis=3).mean(axis=2).mean(axis=1).shape)

输出内容如下,由此可以看出MSELoss的区别,最重要的是,如果要保留每一个样本单独的loss值,就不应该用nn.MSELoss而是F.mse_loss(, reduction="none"):

tensor(1.4745, grad_fn=<MseLossBackward0>)
tensor(1.4745, grad_fn=<MseLossBackward0>)
torch.Size([64, 1, 288, 128])
torch.Size([64])

 

posted @ 2024-01-12 14:53  倦鸟已归时  阅读(93)  评论(0编辑  收藏  举报