Loading

优雅地实现多头自注意力——使用einsum(爱因斯坦求和)进行矩阵运算

einsum函数说明

pytorch文档说明:\(torch.einsum(equation, **operands)\) 使用基于爱因斯坦求和约定的符号,将输入operands的元素沿指定的维数求和。einsum允许计算许多常见的多维线性代数阵列运算,方法是基于爱因斯坦求和约定以简写格式表示它们。主要是省略了求和号,总体思路是在箭头左边用一些下标标记输入operands的每个维度,并在箭头右边定义哪些下标是输出的一部分。通过将operands元素与下标不属于输出的维度的乘积求和来计算输出。其方便之处在于可以直接通过求和公式写出运算代码。

# 矩阵乘法例子引入
a = torch.rand(2,3)
b = torch.rand(3,4)
c = torch.einsum("ik,kj->ij", [a, b])
# 等价操作 torch.mm(a, b)

两个基本概念,自由索引/自由标(Free indices)和求和索引/哑标(Summation indices):

  • 自由索引,出现在箭头右边的索引
  • 求和索引,只出现在箭头左边的索引,表示中间计算结果需要这个维度上求和之后才能得到输出,

接着是介绍三条基本规则:

  • 规则一,equation 箭头左边,在不同输入之间重复出现的索引表示,把输入张量沿着该维度做乘法操作,比如还是以上面矩阵乘法为例, "ik,kj->ij",k 在输入中重复出现,所以就是把 a 和 b 沿着 k 这个维度作相乘操作;
  • 规则二,只出现在 equation 箭头左边的索引,表示中间计算结果需要在这个维度上求和,也就是上面提到的求和索引;
  • 规则三,equation 箭头右边的索引顺序可以是任意的,比如上面的 "ik,kj->ij" 如果写成 "ik,kj->ji",那么就是返回输出结果的转置,用户只需要定义好索引的顺序,转置操作会在 einsum 内部完成。

两条特殊规则:

  • equation 可以不写包括箭头在内的右边部分,那么在这种情况下,输出张量的维度会根据默认规则推导。就是把输入中只出现一次的索引取出来,然后按字母表顺序排列,比如上面的矩阵乘法 "ik,kj->ij" 也可以简化为 "ik,kj",根据默认规则,输出就是 "ij" 与原来一样;
  • equation 中支持 "..." 省略号,用于表示用户并不关心的索引,详见下方转置例子

单操作数

获取对角线元素diagonal

einsum 可以不做求和。举个例子,获取二维方阵的对角线元素,结果放入一维向量。

\[A_i = B_{ii} \]

上面,A 是一维向量,B 是二维方阵。使用 einsum 记法,可以写作 ii->i

torch.einsum('ii->i', torch.randn(4, 4))

# 以下操作互相等价
a = torch.randn(4,4)
c = torch.einsum('ii->i', a)
c = torch.diagonal(a, 0)


迹trace

求解矩阵的迹(trace),即对角线元素的和。

\[t = \Sigma_{i=1}^{n} A_{ii} \]

t 是常量,A 是二维方阵。按照前面的做法,省略 ΣΣ,左右两边对调,省去矩阵和 t,剩下的就是ii->或省略箭头ii

torch.einsum('ii', torch.randn(4, 4))

矩阵转置

\[A_{ij} = B_{ji} \]

A 和 B 都是二维方阵。einsum 可以表达为 ij->ji

torch.einsum('ij -> ji',a)

pytorch 中,还支持省略前面的维度。比如,只转置最后两个维度,可以表达为 ...ij->...ji。下面展示了一个含有四个二维矩阵的三维矩阵,转置三维矩阵中的每个二维矩阵。

A = torch.randn(2, 3, 4, 5)
torch.einsum('...ij->...ji', A).shape
# torch.Size([2, 3, 5, 4])

# 等价操作
A.permute(0,1,3,2)
A.transpose(2,3)

求和

\[b=\sum_{i} \sum_{j} A_{i j}=A_{i j} \]

a = torch.arange(6).reshape(2, 3)
torch.einsum('ij->', [a])
tensor(15.)

列求和:

\[b_{j}=\sum_{i} A_{i j}=A_{i j} \]

a = torch.arange(6).reshape(2, 3)
torch.einsum('ij->j', [a])
tensor([ 3.,  5.,  7.])

# 等价操作
torch.sum(a, 0) # (dim参数0) means the dimension or dimensions to reduce.

双操作数

矩阵乘法

\[A_{ij} = \Sigma_{k=1}^{n} B_{ik} C_{kj} \]

第一个学习的 einsum 表达式是,ik,kj->ij。前面提到过,爱因斯坦求和记法可以理解为懒人求和记法。将上述公式中的 ΣΣ 去掉,并且将左右两边对调一下,省去矩阵之后,剩下的就是 ik,kj->ij 了。

torch.einsum('ik,kj->ij', a, b) 

# 可用两个矩阵测试以下矩阵乘法操作互相等价
a = torch.randn(2,3)
b = torch.randn(3,4)
c = torch.matmul(a,b)
c = torch.einsum('ik,kj->ij', a, b)
c = a.mm(b) 
c = torch.mm(a, b) 
c = a @ b

矩阵-向量相乘

\[c_{i}=\sum_{k} A_{i k} b_{k}=A_{i k} b_{k} \]

a = torch.arange(6).reshape(2, 3)
b = torch.arange(3)
torch.einsum('ik,k->i', [a, b])

tensor([  5.,  14.])

批量矩阵乘 batch matrix multiplication

\[C_{bik}=\sum_{k} A_{bij} B_{bjk}=A_{bij} B_{bjk} \]

>>> As = torch.randn(3,2,5)
>>> Bs = torch.randn(3,5,4)
>>> torch.einsum('bij,bjk->bik', As, Bs)
tensor([[[-1.0564, -1.5904,  3.2023,  3.1271],
        [-1.6706, -0.8097, -0.8025, -2.1183]],

        [[ 4.2239,  0.3107, -0.5756, -0.2354],
        [-1.4558, -0.3460,  1.5087, -0.8530]],

        [[ 2.8153,  1.8787, -4.3839, -1.2112],
        [ 0.3728, -2.1131,  0.0921,  0.8305]]])

# 等价操作
torch.bmm(As, Bs)

向量内积 dot

\[c=\sum_{i} a_{i} b_{i}=a_{i} b_{i} \]

a = torch.arange(3)
b = torch.arange(3,6)  # [3, 4, 5]
torch.einsum('i,i->', [a, b])
# tensor(14.)

# 等价操作
torch.dot(a, b)

矩阵内积 dot

\[c=\sum_{i} \sum_{j} A_{i j} B_{i j}=A_{i j} B_{i j} \]

a = torch.arange(6).reshape(2, 3)
b = torch.arange(6,12).reshape(2, 3)
torch.einsum('ij,ij->', [a, b])
tensor(145.)

哈达玛积

\[C_{i j}=A_{i j} B_{i j} \]

a = torch.arange(6).reshape(2, 3)
b = torch.arange(6,12).reshape(2, 3)
torch.einsum('ij,ij->ij', [a, b])
tensor([[  0.,   7.,  16.],
        [ 27.,  40.,  55.]])

外积 outer

\[C_{i j}=a_{i} b_{j} \]

a = torch.arange(3)
b = torch.arange(3,7)
torch.einsum('i,j->ij', [a, b])

tensor([[  0.,   0.,   0.,   0.],
        [  3.,   4.,   5.,   6.],
        [  6.,   8.,  10.,  12.]])

einsum其他规则和例子判断:

  • 输入中多次出现的字符,将被用作求和。例子,kj,ji 完整的表达式是 kj,ji->ik,矩阵乘法再相乘。
  • 输出可以指定,但是输出中的每个字符必须在输入中出现至少一次,输出的每个字符在输出中只能出现最多一次。例子,ab->aa 是非法的,ab->c 是非法的,ab->a 是合法的。
  • 省略符 ... 是用来跳过部分维度。例子,...ij,...jk 表示 batch 矩阵乘法。
  • 在输出没有指定的情况下,省略符优先级高于普通字符。例子,b...a 完整的表达式是 b...a->...ab,可以将一个形状为 (a,b,c) 的矩阵变为形状为 (b,c,a) 的矩阵。
  • 允许多个矩阵输入,表达式中使用逗号分开不同矩阵输入的下标。例子,i,i,i 表示将三个一维向量按位相乘,并相加。
  • 除了箭头,其他任何地方都可以加空格。例子,i j , j k -> ik 是合法的,ij,jk - > ik 是非法的。
  • 输入的表达式,维度需要和输入的矩阵对上,不能多也不能少。比如一个 shape 为 (4,3,3) 的矩阵,表达式 ab->a 是非法的,abc-> 是合法的。

实际使用

实现multi headed attention

https://nn.labml.ai/transformers/mha.html
如何优雅地实现多头自注意力

计算注意力score:

\[Q K^{\top} or S_{i j b h}=\sum_{d} Q_{i b h d} K_{j b h d} \]

# q k v均为 [seq_len, batch_size, heads, d_k] 
torch.einsum('ibhd,jbhd->ijbh', query, key) # 理解为ibhd,jbhd->ibhj->ijbh

计算attention输出:

\[\underset{\text { seq }}{\operatorname{softmax}}\left(\frac{Q K^{\top}}{\sqrt{d_{k}}}\right) V \]

# attn [seq_len, seq_len, batch_size, heads]
# value [seq_len, batch_size, heads, d_k] 

x = torch.einsum("ijbh,jbhd->ibhd", attn, value)
# x [seq_len, batch_size, heads, d_k]


参考文献:

https://zhuanlan.zhihu.com/p/361209187
如何优雅地实现多头自注意力
https://rockt.github.io/2018/04/30/einsum **

posted @ 2022-05-08 11:50  MapleTx  阅读(3675)  评论(0编辑  收藏  举报