pytorch的nn.MSELoss损失函数
MSE是mean squared error的缩写,即平均平方误差,简称均方误差。
MSE是逐元素计算的,计算公式为:
旧版的nn.MSELoss()函数有reduce、size_average两个参数,新版的只有一个reduction参数了,功能是一样的。reduction的意思是维度要不要缩减,以及怎么缩减,有三个选项:
- 'none': no reduction will be applied.
- 'mean': the sum of the output will be divided by the number of elements in the output.
- 'sum': the output will be summed.
如果不设置reduction参数,默认是'mean'。
程序示例:
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 | import torch import torch.nn as nn a = torch.tensor([[ 1 , 2 ], [ 3 , 4 ]], dtype = torch.float32) b = torch.tensor([[ 3 , 5 ], [ 8 , 6 ]], dtype = torch.float32) loss_fn1 = torch.nn.MSELoss(reduction = 'none' ) loss1 = loss_fn1(a, b) print (loss1) # 输出结果:tensor([[ 4., 9.], # [25., 4.]]) loss_fn2 = torch.nn.MSELoss(reduction = 'sum' ) loss2 = loss_fn2(a, b) print (loss2) # 输出结果:tensor(42.) loss_fn3 = torch.nn.MSELoss(reduction = 'mean' ) loss3 = loss_fn3(a, b) print (loss3) # 输出结果:tensor(10.5000) |
对于三维的输入也是一样的:
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 | a = torch.randint( 0 , 9 , ( 2 , 2 , 3 )). float () b = torch.randint( 0 , 9 , ( 2 , 2 , 3 )). float () print ( 'a:\n' , a) print ( 'b:\n' , b) loss_fn1 = torch.nn.MSELoss(reduction = 'none' ) loss1 = loss_fn1(a, b) print ( 'loss_none:\n' , loss1) loss_fn2 = torch.nn.MSELoss(reduction = 'sum' ) loss2 = loss_fn2(a, b) print ( 'loss_sum:\n' , loss2) loss_fn3 = torch.nn.MSELoss(reduction = 'mean' ) loss3 = loss_fn3(a, b) print ( 'loss_mean:\n' , loss3) |
运行结果:
参考资料:
【推荐】国内首个AI IDE,深度理解中文开发场景,立即下载体验Trae
【推荐】编程新体验,更懂你的AI,立即体验豆包MarsCode编程助手
【推荐】抖音旗下AI助手豆包,你的智能百科全书,全免费不限次数
【推荐】轻量又高性能的 SSH 工具 IShell:AI 加持,快人一步
· 记一次.NET内存居高不下排查解决与启示
· 探究高空视频全景AR技术的实现原理
· 理解Rust引用及其生命周期标识(上)
· 浏览器原生「磁吸」效果!Anchor Positioning 锚点定位神器解析
· 没有源码,如何修改代码逻辑?
· 全程不用写代码,我用AI程序员写了一个飞机大战
· MongoDB 8.0这个新功能碉堡了,比商业数据库还牛
· 记一次.NET内存居高不下排查解决与启示
· DeepSeek 开源周回顾「GitHub 热点速览」
· 白话解读 Dapr 1.15:你的「微服务管家」又秀新绝活了