FlashAttention简介

前置知识

在GPU进行矩阵运算的时候,内部的运算单元具有和CPU类似的存储金字塔。

img

如果采用经典的Attention的计算方式,需要保存中间变量S和注意力矩阵O,这样子会产生很大的现存占用,并且这些数据的传输也会占用很多带宽和内存。

img

FlashAttention采用分块的方式来进行计算,这样子就可以减少中间变量的存储,同时也可以减少数据的传输。

img

具体的思想是改变Attention的运算顺序,标准是先计算 S=QK,O=Softmax(S),R=OV.
FlashAttention的计算顺序是先计算 R=OV,S=QK,O=Softmax(S).在这个过程中需要保存一些变量用于最终计算Softmax,并且在计算过程中进行分块,利用SRAM的带块,减少HBM的使用。
具体算法如下(ForWard):

img

img

posted @   chenfengshijie  阅读(149)  评论(0编辑  收藏  举报
相关博文:
阅读排行:
· 全程不用写代码,我用AI程序员写了一个飞机大战
· MongoDB 8.0这个新功能碉堡了,比商业数据库还牛
· 记一次.NET内存居高不下排查解决与启示
· DeepSeek 开源周回顾「GitHub 热点速览」
· 白话解读 Dapr 1.15:你的「微服务管家」又秀新绝活了
点击右上角即可分享
微信分享提示