einsum: Einstein summation

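While reading the PyTorch code of a vision transformer recently, I came across torch.einsum (np.einsum works the same way), and it turns out to be a remarkably handy operation.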

For example:

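import torch
from einops import rearrange
t = torch.randn(2, 4, 3)
# split the last dim (size 3 = d * k, with k=3) into three tensors of shape (b, t, d)
q, k, v = tuple(rearrange(t, 'b t (d k) -> k b t d', k=3))
print(q, '\n', k)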

>>>
tensor([[[-0.9011],
         [-0.2627],
         [ 0.4202],
         [-0.3396]],

        [[ 0.0530],
         [ 0.5980],
         [ 0.1464],
         [ 0.7939]]]) 
 tensor([[[-1.0567],
         [ 0.0425],
         [-0.2160],
         [-2.2235]],

        [[ 0.3932],
         [-0.5011],
         [ 0.0748],
         [-1.3025]]])
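
To make the pattern 'b t (d k) -> k b t d' more concrete, here is a minimal sketch of the same split written with plain PyTorch operations only. It assumes the einops convention that inside a group such as (d k) the rightmost name varies fastest; the helper name split_qkv and its parameter n_parts are made up for this illustration.

```python
import torch

def split_qkv(t: torch.Tensor, n_parts: int = 3):
    """Hypothetical helper mimicking rearrange(t, 'b t (d k) -> k b t d', k=n_parts)."""
    b, n, last = t.shape
    d = last // n_parts                # per-tensor feature size
    # (b, n, d * k) -> (b, n, d, k): k is the fastest-varying index in the group
    t = t.reshape(b, n, d, n_parts)
    # move the k axis to the front: (k, b, n, d)
    t = t.permute(3, 0, 1, 2)
    return tuple(t)                    # iterating a tensor splits along its first axis

t = torch.randn(2, 4, 3)
q, k, v = split_qkv(t)
print(q.shape, k.shape, v.shape)
```

With a last dimension of size 3 and three parts, d is 1, so q, k and v are simply the three columns of t, each kept as a trailing dimension of size 1.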

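As the output shows, this produces the q, k, v of a transformer. Each has shape (2, 4, 1), and the dimensions mean (batch_size, token, dim). The next step is the product q*k^T. Since dim is 1 here, this amounts to an outer product of two length-4 vectors for each batch element; in general it gives the dot product between every pair of tokens.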

# raw q*k^T scores; despite the variable name, no 1/sqrt(dim) scaling is applied on this line
scaled_dot_prod = torch.einsum('b i d , b j d -> b i j', q, k)

scaled_dot_prod
>>>tensor([[[ 0.9523, -0.0383,  0.1947,  2.0037],
         [ 0.2776, -0.0112,  0.0568,  0.5842],
         [-0.4440,  0.0179, -0.0908, -0.9342],
         [ 0.3588, -0.0144,  0.0734,  0.7551]],

        [[ 0.0208, -0.0265,  0.0040, -0.0690],
         [ 0.2351, -0.2996,  0.0447, -0.7789],
         [ 0.0575, -0.0733,  0.0109, -0.1906],
         [ 0.3122, -0.3978,  0.0594, -1.0341]]])

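Note that q and k have the same shape here. With an ordinary matrix multiplication you would first have to transpose k so that the inner dimensions line up; with einsum you name the dimensions and state directly which ones are multiplied together and summed over.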
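
To spell out what the subscript string 'b i d , b j d -> b i j' computes: every index that is missing on the right-hand side (here d) is summed over, so out[b, i, j] is the sum over d of q[b, i, d] * k[b, j, d]. The sketch below checks this against explicit loops and against the conventional transpose-then-multiply route. The random inputs and the feature size of 5 are arbitrary choices for illustration, not values from the post.

```python
import torch

torch.manual_seed(0)
q = torch.randn(2, 4, 5)   # (batch, tokens, dim); sizes chosen for illustration
k = torch.randn(2, 4, 5)

# einsum: d is absent from the output, so it is summed over
out_einsum = torch.einsum('b i d , b j d -> b i j', q, k)

# the same computation written out with explicit loops
out_loops = torch.zeros(2, 4, 4)
for b in range(2):
    for i in range(4):
        for j in range(4):
            out_loops[b, i, j] = (q[b, i] * k[b, j]).sum()

# conventional route: swap k's last two dims, then do a batched matmul
out_matmul = torch.matmul(q, k.transpose(-1, -2))

print(torch.allclose(out_einsum, out_loops, atol=1e-5))
print(torch.allclose(out_einsum, out_matmul, atol=1e-5))
```

The loop version is only there to show the semantics; in practice the einsum or matmul form is the one to use.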

Because the dimensions are matched by name, k can also be rearranged to shape (2, 1, 4) before being multiplied with q, for example:

k_ = rearrange(k,'b t d -> b d t')
k_
a_scaled_dot_prod = torch.einsum('b i d , b d j -> b i j', q, k_)

a_scaled_dot_prod

>>>
tensor([[[ 0.9523, -0.0383,  0.1947,  2.0037],
         [ 0.2776, -0.0112,  0.0568,  0.5842],
         [-0.4440,  0.0179, -0.0908, -0.9342],
         [ 0.3588, -0.0144,  0.0734,  0.7551]],

        [[ 0.0208, -0.0265,  0.0040, -0.0690],
         [ 0.2351, -0.2996,  0.0447, -0.7789],
         [ 0.0575, -0.0733,  0.0109, -0.1906],
         [ 0.3122, -0.3978,  0.0594, -1.0341]]])
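
Both spellings give the same numbers because 'b i d , b d j -> b i j' is just the usual batched matrix product of a (b, i, d) tensor with a (b, d, j) tensor. A quick check under stated assumptions: the inputs are fresh random tensors rather than the values printed above, and k.permute stands in for rearrange so the snippet depends only on PyTorch.

```python
import torch

torch.manual_seed(0)
q = torch.randn(2, 4, 1)   # (batch, tokens, dim), same shapes as in the post
k = torch.randn(2, 4, 1)

k_ = k.permute(0, 2, 1)    # same effect as rearrange(k, 'b t d -> b d t')

first = torch.einsum('b i d , b j d -> b i j', q, k)     # k left as (b, t, d)
second = torch.einsum('b i d , b d j -> b i j', q, k_)   # k swapped to (b, d, t)
plain = torch.bmm(q, k_)                                  # ordinary batched matmul

print(k_.shape)
print(torch.allclose(first, second), torch.allclose(second, plain))
```

The only thing that changes between the two einsum calls is where the letter d sits in the second operand's subscripts, which is what tells einsum which axis of that tensor to contract.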

Reference: https://zhuanlan.zhihu.com/p/74462893

posted @ 2021-03-04 13:31  嶙羽