LLM 推理和训练占用显存

https://blog.csdn.net/weixin_44292902/article/details/133767448

https://www.53ai.com/news/finetuning/2024083051493.html 推荐, 讲解训练和推理时的显存占用,lora 和 qlora。

如果模型参数量为 X (fp16), 推理一般占用 2X (模型参数+各种激活值,beamsearch),混合精度训练一般占用 8X。

数据精度

我们都知道:

  • 1 byte = 8 bits
  • 1 KB = 1,024 bytes
  • 1 MB = 1,024 KB
  • 1 GB = 1,024 MB

由此可以明白,一个含有 1G 参数的模型,如果每一个参数都是 32bit(4byte),那么直接加载模型就会占用 4x1G 的显存。

(1)常见的几种精度类型

个人认为只需掌握下图几个常见的数据类型就好,对于更多的精度类型都是可以做到触类旁通发,图源英伟达安培架构白皮书:

img

各种精度的数据结构

可以非常直观地看到,浮点数主要是由符号位(sign)、指数位(exponent)和小数位(mantissa)三部分组成。其实这也类似

符号位都是 1 位(0 表示正,1 表示负),指数位影响浮点数范围,小数位影响精度。

其中 TF32 并不是有 32bit,只有 19bit 不要记错了。BF16 指的是 Brain Float 16,由 Google Brain 团队提出。

(2)具体计算例子

我说实话,讲太多不如一个形象的图片或者例子来得直接,下面我们将通过一个例子来深入理解如何通过这三个部分来得到我们最终的数据。

我以 BF16,如今业界用的最广泛的精度类型来举个栗子,下面的数完全是我用克劳德大哥随机画的:

题目:

先给出具体计算公式:
$$
\text{Value} = (-1)^{\text{sign}} \times 2^{\text{exponent} - 127} \times 1.\text{mantissa}
$$
其中:

- sign:符号位,0 表示正数,1 表示负数。

- exponent:指数值,偏移量为 127。因为指数位为 0~255,减去 127 是为了表示正负。

- mantissa:尾数部分的二进制值。

假设我们有一个 BF16 的二进制数:

0 10000011 0100000

步骤 1:解析各部分

  • 符号位(sign):第 1 位是 0,表示正数。
  • 指数位(exponent):接下来的 8 位是 10000011,转换为十进制是 131
  • 尾数位(mantissa):最后的 7 位是 0100000,表示的小数是 0.25(计算方式见下面)。

步骤 2:计算指数部分

指数值的实际值需要减去偏移量 127:
[
\text{实际指数} = 131 - 127 = 4
]

步骤 3:计算尾数部分(mantissa)

尾数部分的 7 位为 0100000,我们需要将它转换为小数。BF16 的尾数部分采用隐式小数的表示法,即默认存在一个隐含的 1,所以它的数值表示为:
[
1 + 0.25 = 1.25
]
其中,0.25 是通过尾数的位权重计算得到的(最高位对应 (2^{-1}),接下来依次为 (2{-2})、(2) 等)。

步骤 4:计算最终数值

将所有部分代入公式:
[
\text{Value} = (-1)^0 \times 2^4 \times 1.25 = 1 \times 16 \times 1.25 = 20
]

最终结果

这个 BF16 的二进制数 0 10000011 0100000 对应的十进制数是 20

推理

图片

FP32 需要占用 4 Byte 空间,7B 的模型,存储参数需要 7Billion * 4Byte ~= 7G *4 = 28G,如果是 FP16 则为 14 G, 如果为 int8 则为 7G, int4 则为 3.5G。但是性能也会有损失:

Model 5-shot C-Eval MMLU CMMLU
Baichuan-13B-Base 52.4 51.6 55.3
Baichuan-13B-Base-int8 51.2 49.9 54.5
Baichuan-13B-Base-int4 47.6 46.0 51.0

注意上面只是加载模型到显存,模型运算时的一些临时变量也需要申请空间,比如你beam search的时候。所以真正做推理的时候记得留一些Buffer,不然就容易OOM。

训练

图片

训练主要有模型权重、优化器、梯度、激活值 几个部分。

  1. 模型权重(Model Parameters)fp16存储:X

  2. 优化器(Optimizer)

不同的优化器对显存的占用不同:

  • AdamW:使用 fp32,每个参数需要占用8个字节,因为需要维护两个状态变量(动量和二阶矩估计),以及fp32的模型权重。因此,优化器的显存占用是全精度(float32)模型权重的3倍。3*2X

  • bitsandbytes优化的AdamW:每个参数需要占用2个字节,因此优化器的显存占用是全精度模型权重的一半。

  • SGD:优化器的显存占用和全精度模型权重一样。

    [ \text{显存占用} = \text{模型权重显存占用} ]

  1. 梯度(Gradients)

梯度的显存占用与模型权重的显存占用相同,因为每个参数都有一个对应的梯度,可以用 fp16存储。X

  1. 计算图内部变量(Forward Activations)

计算图内部变量在前向传播和反向传播过程中需要存储,因此也需要占用显存。这部分显存占用与模型的具体实现和批量大小(batch size)成正比。批量大小越大,这部分显存占用越多。同样的结论也适用于序列长度(sequence length)。

例如,以下代码:

y = self.net_a(x)
z = self.net_b(y)

在这种情况下,中间变量 xy 都需要存储。但是,如果写成下面这样,y 就不需要存储了:

z = self.net_b(self.net_a(x))

理论上,一个 net block 可以完全用函数包裹起来,不使用中间变量。下一代计算框架可能会采用函数式编程语言来优化这部分显存占用。