手撕Transformer之CrossAttention

特别感谢@lz.pan对本文的斧正.


我们来动手编写一个多头注意力模块。

 

首先直接开导:

import torch
from torch import nn
import torch.nn.functional as F
import math

 

导完之后,很舒服,进行下一步。

复制代码
class Multiheadattention(nn.Module):
    """Multi-head attention usable for both self- and cross-attention.

    Projects q/k/v from ``input_dim`` to ``d_model``, splits the result into
    ``heads`` parallel heads, applies scaled dot-product attention, then
    concatenates the heads and projects back to ``input_dim``.
    """

    def __init__(self, input_dim, heads, d_model):
        """
        Args:
            input_dim: feature size of the incoming q/k/v tensors.
            heads: number of attention heads; must evenly divide ``d_model``.
            d_model: total width of the internal attention space.

        Raises:
            ValueError: if ``d_model`` is not divisible by ``heads``.
        """
        super().__init__()
        # Guard early: integer division below would otherwise silently drop
        # dimensions and break the reshape in forward().
        if d_model % heads != 0:
            raise ValueError(
                f"d_model ({d_model}) must be divisible by heads ({heads})"
            )
        self.d_model = d_model
        self.heads_num = heads
        self.head_dim = d_model // heads
        self.input_dim = input_dim

        # Separate projections so q may come from one sequence while k/v come
        # from another (cross-attention); q_len and kv_len may differ.
        self.to_q = nn.Linear(self.input_dim, self.d_model)
        self.to_k = nn.Linear(self.input_dim, self.d_model)
        self.to_v = nn.Linear(self.input_dim, self.d_model)
        self.to_out = nn.Linear(self.d_model, self.input_dim)

    def forward(self, q, k, v, mask=None):
        """Compute multi-head attention.

        Args:
            q: (batch, q_len, input_dim) query tensor.
            k: (batch, kv_len, input_dim) key tensor.
            v: (batch, kv_len, input_dim) value tensor.
            mask: optional tensor broadcastable to
                (batch, heads, q_len, kv_len); positions where it equals 0
                are excluded from attention. Default None (no masking),
                which preserves the original behavior.

        Returns:
            (batch, q_len, input_dim) attended output.
        """
        bs = q.shape[0]
        # (bs, len, d_model) -> (bs, heads, len, head_dim)
        q = self.to_q(q).view(bs, -1, self.heads_num, self.head_dim).transpose(1, 2)
        k = self.to_k(k).view(bs, -1, self.heads_num, self.head_dim).transpose(1, 2)
        v = self.to_v(v).view(bs, -1, self.heads_num, self.head_dim).transpose(1, 2)

        # Scaled dot-product scores: (bs, heads, q_len, kv_len). Scaling by
        # sqrt(head_dim) keeps logits in a range where softmax stays soft.
        scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(self.head_dim)
        if mask is not None:
            # -inf logits become exactly 0 after softmax.
            scores = scores.masked_fill(mask == 0, float("-inf"))
        attn = F.softmax(scores, dim=-1)

        out = torch.matmul(attn, v)  # (bs, heads, q_len, head_dim)
        # Merge heads back: (bs, q_len, d_model), then project to input_dim.
        out = out.transpose(1, 2).contiguous().view(bs, -1, self.d_model)
        return self.to_out(out)
复制代码

 

最后我们进行测试:

复制代码
# Smoke test: cross-attention where the query sequence (256 tokens) attends
# over a shorter key/value sequence (77 tokens).
heads = 2
batch_size = 4
input_dim = 32

multiheadattn = Multiheadattention(input_dim, heads, input_dim)

seq_lens = {"q": 256, "kv": 77}
q = torch.randn(batch_size, seq_lens["q"], input_dim)
k, v = (torch.randn(batch_size, seq_lens["kv"], input_dim) for _ in range(2))

# Output keeps the query's sequence length and the original feature size.
print(multiheadattn(q, k, v).shape)
复制代码

 

出来了torch.Size([4, 256, 32]),舒服了。

posted @   老八蜜汁小憨包  阅读(722)  评论(0编辑  收藏  举报
相关博文:
阅读排行:
· Manus爆火,是硬核还是营销?
· 终于写完轮子一部分:tcp代理 了,记录一下
· 震惊!C++程序真的从main开始吗?99%的程序员都答错了
· 别再用vector<bool>了!Google高级工程师:这可能是STL最大的设计失误
· 单元测试从入门到精通
点击右上角即可分享
微信分享提示