pytorch的nn.MSELoss损失函数
MSE是mean squared error的缩写,即平均平方误差,简称均方误差。
MSE是逐元素计算的,计算公式为:
旧版的nn.MSELoss()函数有reduce、size_average两个参数,新版的只有一个reduction参数了,功能是一样的。reduction的意思是维度要不要缩减,以及怎么缩减,有三个选项:
- 'none': no reduction will be applied.
- 'mean': the sum of the output will be divided by the number of elements in the output.
- 'sum': the output will be summed.
如果不设置reduction参数,默认是'mean'。
程序示例:
import torch import torch.nn as nn a = torch.tensor([[1, 2], [3, 4]], dtype=torch.float32) b = torch.tensor([[3, 5], [8, 6]], dtype=torch.float32) loss_fn1 = torch.nn.MSELoss(reduction='none') loss1 = loss_fn1(a, b) print(loss1) # 输出结果:tensor([[ 4., 9.], # [25., 4.]]) loss_fn2 = torch.nn.MSELoss(reduction='sum') loss2 = loss_fn2(a, b) print(loss2) # 输出结果:tensor(42.) loss_fn3 = torch.nn.MSELoss(reduction='mean') loss3 = loss_fn3(a, b) print(loss3) # 输出结果:tensor(10.5000)
对于三维的输入也是一样的:
a = torch.randint(0, 9, (2, 2, 3)).float() b = torch.randint(0, 9, (2, 2, 3)).float() print('a:\n', a) print('b:\n', b) loss_fn1 = torch.nn.MSELoss(reduction='none') loss1 = loss_fn1(a, b) print('loss_none:\n', loss1) loss_fn2 = torch.nn.MSELoss(reduction='sum') loss2 = loss_fn2(a, b) print('loss_sum:\n', loss2) loss_fn3 = torch.nn.MSELoss(reduction='mean') loss3 = loss_fn3(a, b) print('loss_mean:\n', loss3)
运行结果:
参考资料: