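# perf flash att: profile PyTorch's fused scaled_dot_product_attention (flash) against a naive matmul/softmax attention.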

import torch
import torch.nn.functional as F
import math


def fmha(q, k, v, is_causal=True):
    """Compute attention with PyTorch's fused ``scaled_dot_product_attention``.

    No explicit attention mask and no dropout are used. When ``is_causal`` is
    True, the causal mask is applied by the fused op itself. Which backend kernel
    runs is decided inside PyTorch (see the SDPA documentation).
    """
    sdpa = F.scaled_dot_product_attention
    return sdpa(
        q,
        k,
        v,
        attn_mask=None,
        dropout_p=0.0,
        is_causal=is_causal,
    )

def mha(q, k, v, is_causal=True, seq_length=1024):
    """Naive (unfused) scaled dot-product attention, the baseline for ``fmha``.

    Computes ``softmax(q @ k^T / sqrt(d) [+ causal mask]) @ v`` with explicit
    matmuls, materialising the full score matrix, so that it produces the same
    result as ``F.scaled_dot_product_attention`` with the same ``is_causal``.

    Args:
        q: query tensor whose last two dims are (seq_q, head_dim); the profiling
            script passes (b, nh, t, hs).
        k: key tensor whose last two dims are (seq_k, head_dim).
        v: value tensor whose second-to-last dim is seq_k.
        is_causal: if True, query position i only attends to key positions
            j <= i (lower-triangular mask, aligned at the top-left corner).
        seq_length: ignored; kept only so existing callers keep working. The
            sequence lengths are read from ``q`` and ``k`` so the mask always
            matches the actual inputs.

    Returns:
        Tensor of shape (..., seq_q, v.shape[-1]).
    """
    # Standard attention scaling is 1/sqrt(head_dim), not 1/head_dim.
    scale = 1.0 / math.sqrt(q.shape[-1])
    attention_scores = torch.matmul(q, k.transpose(-2, -1)) * scale  # -> (..., t_q, t_k)
    if is_causal:
        q_len, k_len = q.shape[-2], k.shape[-2]
        keep = torch.ones(q_len, k_len, dtype=torch.bool, device=q.device).tril()
        # -inf scores become exactly 0 probability after softmax.
        attention_scores = attention_scores.masked_fill(~keep, float("-inf"))
    attention_probs = F.softmax(attention_scores, dim=-1)
    return torch.matmul(attention_probs, v)  # -> (..., t_q, hs)

# Benchmark inputs laid out as (batch, n_heads, seq_len, head_dim).
# PyTorch's flash-attention SDPA backend only accepts fp16/bf16 inputs. With
# fp32 tensors, SDPA falls back to another backend (memory-efficient or math),
# so the profile would not measure flash attention at all.
DTYPE = torch.float16
DEVICE = "cuda"

q = torch.randn(2, 16, 1024, 64, device=DEVICE, dtype=DTYPE)
k = torch.randn(2, 16, 1024, 64, device=DEVICE, dtype=DTYPE)
v = torch.randn(2, 16, 1024, 64, device=DEVICE, dtype=DTYPE)


def _profile_cuda(fn, n_iters=2):
    """Run ``fn()`` under the CUDA profiler ``n_iters`` times, printing each table.

    Runs are repeated because the first one may include one-time warm-up
    costs; later runs are closer to steady state. Add
    ``torch.profiler.ProfilerActivity.CPU`` to ``activities`` to also see host-side ops.
    """
    for _ in range(n_iters):
        with torch.profiler.profile(
            activities=[torch.profiler.ProfilerActivity.CUDA],
        ) as prof:
            fn()
        print(prof.key_averages().table(
            sort_by="self_cuda_time_total", row_limit=-1))


_profile_cuda(lambda: fmha(q, k, v))
print('--------')
_profile_cuda(lambda: mha(q, k, v))

# posted @ 2023-05-18 13:43  xytpai  阅读(30)  评论(0)  (views: 30, comments: 0)