Hand-Writing the Transformer: CrossAttention
Special thanks to @lz.pan for corrections to this article.
Let's write a multi-head attention module by hand.
First, the imports:
import torch
from torch import nn
import torch.nn.functional as F
import math
With the imports in place, on to the next step: the module itself.
class Multiheadattention(nn.Module):
    def __init__(self, input_dim, heads, d_model):
        super(Multiheadattention, self).__init__()
        self.d_model = d_model
        self.head_dim = self.d_model // heads
        self.heads_num = heads
        self.input_dim = input_dim
        # Q/K/V projections map the last dimension: input_dim -> d_model
        self.to_q = nn.Linear(self.input_dim, self.d_model)
        self.to_k = nn.Linear(self.input_dim, self.d_model)
        self.to_v = nn.Linear(self.input_dim, self.d_model)
        # Output projection maps back: d_model -> input_dim
        self.to_out = nn.Linear(self.d_model, self.input_dim)

    def forward(self, q, k, v):
        bs = q.shape[0]
        # (bs, seq_len, input_dim) -> (bs, seq_len, heads_num, head_dim) -> (bs, heads_num, seq_len, head_dim)
        q = self.to_q(q).view(bs, -1, self.heads_num, self.head_dim).transpose(1, 2)
        k = self.to_k(k).view(bs, -1, self.heads_num, self.head_dim).transpose(1, 2)
        v = self.to_v(v).view(bs, -1, self.heads_num, self.head_dim).transpose(1, 2)
        # Scaled dot-product attention scores: (bs, heads_num, seq_len_q, seq_len_kv)
        scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(self.head_dim)
        scores = F.softmax(scores, dim=-1)
        # Weighted sum of the values: (bs, heads_num, seq_len_q, head_dim)
        out = torch.matmul(scores, v)
        # Merge the heads back: (bs, seq_len_q, d_model), then project to input_dim
        out = out.transpose(1, 2).contiguous().view(bs, -1, self.d_model)
        out = self.to_out(out)
        return out
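Before moving on, here is a small sanity check that is not part of the original post: a minimal sketch, assuming PyTorch 2.0 or newer and reusing the imports above, comparing the manual softmax(QK^T / sqrt(d)) V path inside forward() against the built-in F.scaled_dot_product_attention on tensors already shaped (batch, heads, seq, head_dim).

# Sanity check (assumption: PyTorch >= 2.0). The manual scores/softmax/matmul
# path should match F.scaled_dot_product_attention numerically.
bs, heads_num, seq_q, seq_kv, head_dim = 2, 4, 8, 5, 16
q = torch.randn(bs, heads_num, seq_q, head_dim)
k = torch.randn(bs, heads_num, seq_kv, head_dim)
v = torch.randn(bs, heads_num, seq_kv, head_dim)

# Manual path, same computation as in forward() above.
scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(head_dim)
manual = torch.matmul(F.softmax(scores, dim=-1), v)

# Fused built-in path.
builtin = F.scaled_dot_product_attention(q, k, v)
print(torch.allclose(manual, builtin, atol=1e-6))  # expect: True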
Finally, a quick test:
heads = 2
batch_size = 4
input_dim = 32
multiheadattn = Multiheadattention(input_dim, heads, input_dim)
q = torch.randn(batch_size, 256, input_dim)
k = torch.randn(batch_size, 77, input_dim)
v = torch.randn(batch_size, 77, input_dim)
out = multiheadattn(q, k, v)
print(out.shape)
Out comes torch.Size([4, 256, 32]), as expected: in cross-attention the output keeps the query's sequence length (256 here), while the key/value sequence (length 77) only determines what gets attended to.
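As a further interface-level cross-check (again not from the original post, and only a sketch), the same call pattern works with PyTorch's built-in nn.MultiheadAttention when batch_first=True; its weights are initialized independently of our module, so only the shapes are expected to match, not the values.

# Shape-level comparison with the built-in module (assumption: a PyTorch
# version that supports batch_first, i.e. >= 1.9). Values will differ
# because the projection weights are not copied over.
mha = nn.MultiheadAttention(embed_dim=input_dim, num_heads=heads, batch_first=True)
ref_out, attn_weights = mha(q, k, v)  # q: (4, 256, 32), k/v: (4, 77, 32)
print(ref_out.shape)                  # torch.Size([4, 256, 32])

Matching the values as well would require copying the to_q/to_k/to_v/to_out weights into the built-in module's packed projection parameters, which is outside the scope of this post.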