pytorch中网络参数的默认精度

pytorch默认使用单精度float32训练模型,其主要原因为:使用float16训练模型,模型效果会有损失,而使用double(float64)会有2倍的内存压力,且不会带来太多的精度提升,因此默认使用单精度float32训练模型。

 

由于输入类型不一致导致报错:

PyTorch:expected scalar type Float but found Double

表明代码中网络参数类型不统一。

pytorch如何更改默认单精度float32训练模型,而改为torch.float64对模型进行训练呢?

解决办法:把模型的权重参数数据类型和输入数据类型全部设置为torch.float64。

使用torch.set_default_dtype(torch.float64)把模型参数转化为float64,或使用net = net.double()

输入类型使用tensor.type(torch.float64)将输入数据类型转换为torch.float64。

 

 

posted on   那抹阳光1994  阅读(3657)  评论(0编辑  收藏  举报

相关博文:
阅读排行:
· 震惊!C++程序真的从main开始吗?99%的程序员都答错了
· winform 绘制太阳,地球,月球 运作规律
· 【硬核科普】Trae如何「偷看」你的代码?零基础破解AI编程运行原理
· 上周热点回顾(3.3-3.9)
· 超详细:普通电脑也行Windows部署deepseek R1训练数据并当服务器共享给他人
历史上的今天:
2020-06-27 记MongoDB的安装

导航

< 2025年3月 >
23 24 25 26 27 28 1
2 3 4 5 6 7 8
9 10 11 12 13 14 15
16 17 18 19 20 21 22
23 24 25 26 27 28 29
30 31 1 2 3 4 5
点击右上角即可分享
微信分享提示