使用FP8加速PyTorch训练的两种方法总结
在PyTorch中,FP8(8-bit 浮点数)是一个较新的数据类型,用于实现高效的神经网络训练和推理。它主要被设计来降低模型运行时的内存占用,并加快计算速度,同时尽量保持训练和推理的准确性。虽然PyTorch官方在标准发布中尚未全面支持FP8,但是在2.2版本中PyTorch已经包含了对FP8的“有限支持”并且出现了2个新的变量类型,
torch.float8_e4m3fn
和
torch.float8_e5m2
,而H100也支持这种类型,所以这篇文章我们就来介绍如何使用FP8来提高训练效率
https://avoid.overfit.cn/post/0dd1fba546674b48b932260fa8742971