pytorch的nn.MSELoss损失函数

MSE是mean squared error的缩写,即平均平方误差,简称均方误差。

MSE是逐元素计算的,计算公式为:

旧版的nn.MSELoss()函数有reduce、size_average两个参数,新版的只有一个reduction参数了,功能是一样的。reduction的意思是维度要不要缩减,以及怎么缩减,有三个选项:

  • 'none': no reduction will be applied.
  • 'mean': the sum of the output will be divided by the number of elements in the output.
  • 'sum': the output will be summed.

如果不设置reduction参数,默认是'mean'。

程序示例: 

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
import torch
import torch.nn as nn
 
a = torch.tensor([[1, 2], [3, 4]], dtype=torch.float32)
b = torch.tensor([[3, 5], [8, 6]], dtype=torch.float32)
 
loss_fn1 = torch.nn.MSELoss(reduction='none')
loss1 = loss_fn1(a, b)
print(loss1)   # 输出结果:tensor([[ 4.,  9.],
               #                 [25.,  4.]])
 
loss_fn2 = torch.nn.MSELoss(reduction='sum')
loss2 = loss_fn2(a, b)
print(loss2)   # 输出结果:tensor(42.)
 
 
loss_fn3 = torch.nn.MSELoss(reduction='mean')
loss3 = loss_fn3(a, b)
print(loss3)   # 输出结果:tensor(10.5000)

 

对于三维的输入也是一样的:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
a = torch.randint(0, 9, (2, 2, 3)).float()
b = torch.randint(0, 9, (2, 2, 3)).float()
print('a:\n', a)
print('b:\n', b)
 
loss_fn1 = torch.nn.MSELoss(reduction='none')
loss1 = loss_fn1(a, b)
print('loss_none:\n', loss1)
 
loss_fn2 = torch.nn.MSELoss(reduction='sum')
loss2 = loss_fn2(a, b)
print('loss_sum:\n', loss2)
 
 
loss_fn3 = torch.nn.MSELoss(reduction='mean')
loss3 = loss_fn3(a, b)
print('loss_mean:\n', loss3)

运行结果:

 

 参考资料:

pytorch的nn.MSELoss损失函数

 

posted @   Picassooo  阅读(39995)  评论(0编辑  收藏  举报
编辑推荐:
· 记一次.NET内存居高不下排查解决与启示
· 探究高空视频全景AR技术的实现原理
· 理解Rust引用及其生命周期标识(上)
· 浏览器原生「磁吸」效果!Anchor Positioning 锚点定位神器解析
· 没有源码,如何修改代码逻辑?
阅读排行:
· 全程不用写代码,我用AI程序员写了一个飞机大战
· MongoDB 8.0这个新功能碉堡了,比商业数据库还牛
· 记一次.NET内存居高不下排查解决与启示
· DeepSeek 开源周回顾「GitHub 热点速览」
· 白话解读 Dapr 1.15:你的「微服务管家」又秀新绝活了
点击右上角即可分享
微信分享提示