使用Unit Scaling进行FP16 和 FP8 训练
Unit Scaling 是一种新的低精度机器学习方法,能够在没有损失缩放的情况下训练 FP16 和 FP8 中的语言模型。
使用FP16和BFLOAT16替代FP32可以将内存、带宽和计算需求的大幅减少,这也是目前越来越大的模型所需要的。
背景介绍
随着支持fp8的硬件的发展,在不影响效率的前提下,进一步降低精度也成为了可能。但是这些较小的、低精度的格式在实践中并不总是易于使用。对于FP8来说则更加困难。因为这些较小的格式通常将用户限制在更窄的可表示值范围内。为了解决这个问题,Graphcore Research开发了一种新方法,我们称之为Unit Scaling。
上图为FP16和FP8中量化的不同尺度的正态分布的信噪比(SNR)。对于较小的数字格式,信号在较窄的尺度范围内较强。
Unit Scaling是一种模型设计技术,它在初始化时根据缩放原则进行操作:也就是说对激活、权重和梯度的单位方差进行缩放。模型会自动生成针对低精度数字格式进行良好缩放的张量。并且使用更简单,并最大限度地减少这些表示的缺点,与低精度训练的替代方法不同,它引入的开销和额外的复杂性很小。
论文的方法取得了突破性的成果:首次在 FP16 甚至 FP8 中准确地训练了 BERT Base 和 BERT Large 模型,并且没有缩放的性能损失。模型也不需要额外的超参数,可以直接使用。
对于关心结果并因此希望在 FP16 和 FP8 中进行训练的人来说,Unit Scaling提供了一个直接的解决方案。
完整文章:
https://avoid.overfit.cn/post/dfcaa9c45d70421a98f4df52a9e83610