einops方法
该方法可以快速实现矩阵(张量)的维度变换。
import torch
import torch.nn as nn
from einops import rearrange # 快速矩阵变化
class TestAttentionQKV(nn.Module):
    """Demonstrate splitting a fused QKV projection into per-head tensors.

    A single bias-free linear layer maps the last axis of the input from
    ``dim`` to ``3 * heads * dim_head`` features. ``forward`` cuts that
    result into three equal chunks and uses ``einops.rearrange`` to move the
    head axis in front of the sequence axis, printing the shape at each step.

    Args:
        dim: Size of the last axis of the input.
        heads: Number of heads the projected features are split into.
        dim_head: Feature size of each head.
    """

    def __init__(self, dim=64, heads=8, dim_head=64):
        # Initialising nn.Module first is required so that ``to_qkv`` is
        # registered as a submodule (parameters(), .to(), state_dict()).
        super().__init__()
        inner_dim = dim_head * heads
        self.dim = dim
        self.heads = heads
        # Stored and printed by ``forward``; it is not applied to any tensor
        # in this class.
        self.scale = dim_head ** -0.5
        # One fused projection produces the three chunks in a single matmul.
        self.to_qkv = nn.Linear(dim, inner_dim * 3, bias=False)

    def forward(self, x):
        """Project ``x``, split it into three chunks and expose the head axis.

        Args:
            x: Tensor of shape ``(b, n, dim)``.

        Returns:
            Tuple ``(q, k, v)``; each tensor has shape
            ``(b, heads, n, dim_head)``.
        """
        # Run the projection once and reuse it for both the shape printout
        # and the split.
        projected = self.to_qkv(x)
        print(projected.shape)

        # Three equal slices along the feature axis.
        qkv = projected.chunk(3, dim=-1)
        for part in qkv:
            print(part.shape)

        # (b, n, h*d) -> (b, h, n, d): unpack the fused feature axis into
        # heads and move the head axis ahead of the sequence axis.
        q, k, v = (
            rearrange(t, 'b n (h d) -> b h n d', h=self.heads) for t in qkv
        )
        print(q.shape)
        print(k.shape)
        print(v.shape)
        print(self.scale)
        return q, k, v
if __name__ == '__main__':
    # Random input of shape (100, 20, 64); the last axis matches the default
    # ``dim`` of TestAttentionQKV.
    input_tensor = torch.rand((100, 20, 64))
    test_attention = TestAttentionQKV()
    # Prints the intermediate shapes produced while splitting into q, k, v.
    test_attention.forward(input_tensor)