GaLore Memory-Efficient LLM Training by Gradient Low-Rank Projection

Zhao J., Zhang Z., Chen B., Wang Z., Anandkumar A. and Tian Y. GaLore: Memory-efficient llm training by gradient low-rank projection. ICML, 2024.

本文提出了一种优化器中高效的缓存策略.

符号说明

  • WtRm×n, 参数;
  • φt, 损失函数;
  • Gt=Wφt(Wt)Rm×n;

GaLore

  • 一般的优化器更新可以归结为:

    Wt+1=WtηG~t,

    其中 G~t=ρt 是对梯度 Gt 进行的一个处理, 在 Adam 中涉及两种动量:

    Mt=β1Mt1+(1β1)Gt,Vt=β2Vt1+(1β2)Gt2,ρt(Gt)=Mt/Vt+ϵ.

  • 像 Adam 这种带 momentum 的, 我们需要缓存 2x 模型大小的量用于更新, 这是非常恐怖的消耗.

  • 作者通过理论分析发现, Gt 随着梯度更新会逐渐趋于低秩, 本文建议一种 gradient low-rank projection (GaLore) 的方式更新:

    Wt+1=WtηG~t,G~t=Ptρt(PtTGtQt)QtT,

    其中 PtRm×r,QtRn×r,rm,n.

  • 即 梯度转移到低秩空间 -> 在低秩空间中完成 ρt -> 恢复到原空间. 于是在整个训练过程中, 我们只需要缓存这些投影矩阵即可. 如下是 Adam 的一个例子 (只用了一半的投影):

  • 收敛性是容易理解的, 每一步更新都相当于:

    φt(W^t),W^t=stop-gradient(Wt)+PW~tQT,W~tRr×r.

  • W~tφt=PTGtQ,

    此时便有:

    W^t+1=W^t+PΔW~QT=W^tηPρt(PTGtQ)QT.

posted @   馒头and花卷  阅读(83)  评论(0编辑  收藏  举报
相关博文:
阅读排行:
· Manus重磅发布:全球首款通用AI代理技术深度解析与实战指南
· 被坑几百块钱后,我竟然真的恢复了删除的微信聊天记录!
· 没有Manus邀请码?试试免邀请码的MGX或者开源的OpenManus吧
· 园子的第一款AI主题卫衣上架——"HELLO! HOW CAN I ASSIST YOU TODAY
· 【自荐】一款简洁、开源的在线白板工具 Drawnix
历史上的今天:
2022-08-27 DropEdge: Towards Deep Graph Convolutional Networks on Node Classification
点击右上角即可分享
微信分享提示