BAdam: A Memory Efficient Full Parameter Optimization Method for Large Language Models

Luo Q., Yu H. and Li X. BAdam: A memory efficient full parameter optimization method for large language models. arXiv preprint, 2024.

本文介绍了一种 Block corrdinate descent (BCD) 的训练方式.

BAdam

  • 当模型本身很大的时候, 训练它会成为一个很大的问题, 所以现在会流行一些 LoRA 等低质方法用于更快速地更新模型.

  • 这个问题其实很大程度上是因为常用地 Adam 至少需要缓存 2x 模型的量, 所以本文提出的 BAdam 就是希望能够每次仅更新其中的一个 block.

  • BAdam 将整个模型分成 \(D\) 份: \(\pi_1, \ldots, \pi_i, \ldots, \pi_D\), 并假设其中的参数为 \(\theta = \{\theta_{\pi_1}, \ldots, \theta_{\pi_i}, \ldots, \theta_{\pi_D}\}\).

  • 每一次那个更新, 仅更新其中的某一个 block:

    \[\theta_{\pi_i}^{t+1} \in \mathop{\text{argmin}} \limits_{\theta_{\pi_i} \in \mathbb{R}^{d_i}} \frac{1}{n} \sum_{j=1}^n \ell_j ( \theta_{\pi_1}^{t+1}, \ldots, \theta_{\pi_{i-1}}^{t+1}, \theta_{\pi_i}, \theta_{\pi_{i+1}}^t, \ldots \theta_{\pi_{D}}^t ). \]

  • 具体的算法如下, 注意到, 对每个 block 会更新 K 次:

代码

[official]

posted @ 2024-08-27 10:12  馒头and花卷  阅读(65)  评论(0编辑  收藏  举报