pytorch中网络参数的默认精度

pytorch默认使用单精度float32训练模型,其主要原因为:使用float16训练模型,模型效果会有损失,而使用double(float64)会有2倍的内存压力,且不会带来太多的精度提升,因此默认使用单精度float32训练模型。

 

由于输入类型不一致导致报错:

PyTorch:expected scalar type Float but found Double

表明代码中网络参数类型不统一。

pytorch如何更改默认单精度float32训练模型,而改为torch.float64对模型进行训练呢?

解决办法:把模型的权重参数数据类型和输入数据类型全部设置为torch.float64。

使用torch.set_default_dtype(torch.float64)把模型参数转化为float64,或使用net = net.double()

输入类型使用tensor.type(torch.float64)将输入数据类型转换为torch.float64。

 

 

posted on 2022-06-27 11:29  那抹阳光1994  阅读(3474)  评论(0编辑  收藏  举报

导航