import torch
import torch.nn.functional as F
import math
def fmha(q, k, v, is_causal=True):
    # Fused attention: dispatches to a FlashAttention / memory-efficient kernel when available.
    return F.scaled_dot_product_attention(q, k, v, attn_mask=None, dropout_p=0.0, is_causal=is_causal)
def mha(q, k, v, is_causal=True):
    # Unfused attention: materializes the full (b, nh, t, t) score matrix.
    seq_length = q.shape[2]
    # Scale by sqrt(head_dim), matching the fused path.
    attention_scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(q.shape[-1])  # -> (b, nh, t, t)
    if is_causal:
        causal_mask = torch.tril(torch.ones(seq_length, seq_length, dtype=torch.bool, device=q.device))
        attention_scores = attention_scores.masked_fill(~causal_mask, float('-inf'))
    attention_probs = F.softmax(attention_scores, dim=-1)  # -> (b, nh, t, t)
    # Dropout is omitted to match dropout_p=0.0 in the fused path.
    out = torch.matmul(attention_probs, v)  # -> (b, nh, t, hs)
    return out
# q, k, v: (batch=2, num_heads=16, seq_len=1024, head_dim=64)
q = torch.randn(2, 16, 1024, 64).cuda()
k = torch.randn(2, 16, 1024, 64).cuda()
v = torch.randn(2, 16, 1024, 64).cuda()
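# Optional sanity check: with the causal mask and sqrt(head_dim) scaling in mha,
# both paths should agree up to floating-point tolerance (the tolerances here are
# a rough choice, not tuned).
assert torch.allclose(fmha(q, k, v), mha(q, k, v), atol=1e-3, rtol=1e-3)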
# Run each profile twice; the first iteration includes warm-up / kernel-selection overhead.
for i in range(2):
    with torch.profiler.profile(
            activities=[
                # torch.profiler.ProfilerActivity.CPU,
                torch.profiler.ProfilerActivity.CUDA,
            ]) as p:
        fmha(q, k, v)
    print(p.key_averages().table(
        sort_by="self_cuda_time_total", row_limit=-1))
print('--------')
for i in range(2):
    with torch.profiler.profile(
            activities=[
                # torch.profiler.ProfilerActivity.CPU,
                torch.profiler.ProfilerActivity.CUDA,
            ]) as p:
        mha(q, k, v)
    print(p.key_averages().table(
        sort_by="self_cuda_time_total", row_limit=-1))