优雅地实现多头自注意力——使用einsum(爱因斯坦求和)进行矩阵运算
einsum函数说明
pytorch文档说明:\(torch.einsum(equation, **operands)\) 使用基于爱因斯坦求和约定的符号,将输入operands的元素沿指定的维数求和。einsum允许计算许多常见的多维线性代数阵列运算,方法是基于爱因斯坦求和约定以简写格式表示它们。主要是省略了求和号,总体思路是在箭头左边用一些下标标记输入operands的每个维度,并在箭头右边定义哪些下标是输出的一部分。通过将operands元素与下标不属于输出的维度的乘积求和来计算输出。其方便之处在于可以直接通过求和公式写出运算代码。
# 矩阵乘法例子引入
a = torch.rand(2,3)
b = torch.rand(3,4)
c = torch.einsum("ik,kj->ij", [a, b])
# 等价操作 torch.mm(a, b)
两个基本概念,自由索引/自由标(Free indices)和求和索引/哑标(Summation indices):
- 自由索引,出现在箭头右边的索引
- 求和索引,只出现在箭头左边的索引,表示中间计算结果需要这个维度上求和之后才能得到输出,
接着是介绍三条基本规则:
- 规则一,equation 箭头左边,在不同输入之间重复出现的索引表示,把输入张量沿着该维度做乘法操作,比如还是以上面矩阵乘法为例, "ik,kj->ij",k 在输入中重复出现,所以就是把 a 和 b 沿着 k 这个维度作相乘操作;
- 规则二,只出现在 equation 箭头左边的索引,表示中间计算结果需要在这个维度上求和,也就是上面提到的求和索引;
- 规则三,equation 箭头右边的索引顺序可以是任意的,比如上面的 "ik,kj->ij" 如果写成 "ik,kj->ji",那么就是返回输出结果的转置,用户只需要定义好索引的顺序,转置操作会在 einsum 内部完成。
两条特殊规则:
- equation 可以不写包括箭头在内的右边部分,那么在这种情况下,输出张量的维度会根据默认规则推导。就是把输入中只出现一次的索引取出来,然后按字母表顺序排列,比如上面的矩阵乘法 "ik,kj->ij" 也可以简化为 "ik,kj",根据默认规则,输出就是 "ij" 与原来一样;
- equation 中支持 "..." 省略号,用于表示用户并不关心的索引,详见下方转置例子
单操作数
获取对角线元素diagonal
einsum 可以不做求和。举个例子,获取二维方阵的对角线元素,结果放入一维向量。
上面,A 是一维向量,B 是二维方阵。使用 einsum 记法,可以写作 ii->i
torch.einsum('ii->i', torch.randn(4, 4))
# 以下操作互相等价
a = torch.randn(4,4)
c = torch.einsum('ii->i', a)
c = torch.diagonal(a, 0)
迹trace
求解矩阵的迹(trace),即对角线元素的和。
t 是常量,A 是二维方阵。按照前面的做法,省略 ΣΣ,左右两边对调,省去矩阵和 t,剩下的就是ii->
或省略箭头ii
torch.einsum('ii', torch.randn(4, 4))
矩阵转置
A 和 B 都是二维方阵。einsum 可以表达为 ij->ji
。
torch.einsum('ij -> ji',a)
pytorch 中,还支持省略前面的维度。比如,只转置最后两个维度,可以表达为 ...ij->...ji
。下面展示了一个含有四个二维矩阵的三维矩阵,转置三维矩阵中的每个二维矩阵。
A = torch.randn(2, 3, 4, 5)
torch.einsum('...ij->...ji', A).shape
# torch.Size([2, 3, 5, 4])
# 等价操作
A.permute(0,1,3,2)
A.transpose(2,3)
求和
a = torch.arange(6).reshape(2, 3)
torch.einsum('ij->', [a])
tensor(15.)
列求和:
a = torch.arange(6).reshape(2, 3)
torch.einsum('ij->j', [a])
tensor([ 3., 5., 7.])
# 等价操作
torch.sum(a, 0) # (dim参数0) means the dimension or dimensions to reduce.
双操作数
矩阵乘法
第一个学习的 einsum 表达式是,ik,kj->ij
。前面提到过,爱因斯坦求和记法可以理解为懒人求和记法。将上述公式中的 ΣΣ 去掉,并且将左右两边对调一下,省去矩阵之后,剩下的就是 ik,kj->ij
了。
torch.einsum('ik,kj->ij', a, b)
# 可用两个矩阵测试以下矩阵乘法操作互相等价
a = torch.randn(2,3)
b = torch.randn(3,4)
c = torch.matmul(a,b)
c = torch.einsum('ik,kj->ij', a, b)
c = a.mm(b)
c = torch.mm(a, b)
c = a @ b
矩阵-向量相乘
a = torch.arange(6).reshape(2, 3)
b = torch.arange(3)
torch.einsum('ik,k->i', [a, b])
tensor([ 5., 14.])
批量矩阵乘 batch matrix multiplication
>>> As = torch.randn(3,2,5)
>>> Bs = torch.randn(3,5,4)
>>> torch.einsum('bij,bjk->bik', As, Bs)
tensor([[[-1.0564, -1.5904, 3.2023, 3.1271],
[-1.6706, -0.8097, -0.8025, -2.1183]],
[[ 4.2239, 0.3107, -0.5756, -0.2354],
[-1.4558, -0.3460, 1.5087, -0.8530]],
[[ 2.8153, 1.8787, -4.3839, -1.2112],
[ 0.3728, -2.1131, 0.0921, 0.8305]]])
# 等价操作
torch.bmm(As, Bs)
向量内积 dot
a = torch.arange(3)
b = torch.arange(3,6) # [3, 4, 5]
torch.einsum('i,i->', [a, b])
# tensor(14.)
# 等价操作
torch.dot(a, b)
矩阵内积 dot
a = torch.arange(6).reshape(2, 3)
b = torch.arange(6,12).reshape(2, 3)
torch.einsum('ij,ij->', [a, b])
tensor(145.)
哈达玛积
a = torch.arange(6).reshape(2, 3)
b = torch.arange(6,12).reshape(2, 3)
torch.einsum('ij,ij->ij', [a, b])
tensor([[ 0., 7., 16.],
[ 27., 40., 55.]])
外积 outer
a = torch.arange(3)
b = torch.arange(3,7)
torch.einsum('i,j->ij', [a, b])
tensor([[ 0., 0., 0., 0.],
[ 3., 4., 5., 6.],
[ 6., 8., 10., 12.]])
einsum其他规则和例子判断:
- 输入中多次出现的字符,将被用作求和。例子,
kj,ji
完整的表达式是kj,ji->ik
,矩阵乘法再相乘。 - 输出可以指定,但是输出中的每个字符必须在输入中出现至少一次,输出的每个字符在输出中只能出现最多一次。例子,
ab->aa
是非法的,ab->c
是非法的,ab->a
是合法的。 - 省略符
...
是用来跳过部分维度。例子,...ij,...jk
表示 batch 矩阵乘法。 - 在输出没有指定的情况下,省略符优先级高于普通字符。例子,
b...a
完整的表达式是b...a->...ab
,可以将一个形状为(a,b,c)
的矩阵变为形状为(b,c,a)
的矩阵。 - 允许多个矩阵输入,表达式中使用逗号分开不同矩阵输入的下标。例子,
i,i,i
表示将三个一维向量按位相乘,并相加。 - 除了箭头,其他任何地方都可以加空格。例子,
i j , j k -> ik
是合法的,ij,jk - > ik
是非法的。 - 输入的表达式,维度需要和输入的矩阵对上,不能多也不能少。比如一个 shape 为
(4,3,3)
的矩阵,表达式ab->a
是非法的,abc->
是合法的。
实际使用
实现multi headed attention
https://nn.labml.ai/transformers/mha.html
如何优雅地实现多头自注意力
计算注意力score:
# q k v均为 [seq_len, batch_size, heads, d_k]
torch.einsum('ibhd,jbhd->ijbh', query, key) # 理解为ibhd,jbhd->ibhj->ijbh
计算attention输出:
# attn [seq_len, seq_len, batch_size, heads]
# value [seq_len, batch_size, heads, d_k]
x = torch.einsum("ijbh,jbhd->ibhd", attn, value)
# x [seq_len, batch_size, heads, d_k]
参考文献:
https://zhuanlan.zhihu.com/p/361209187
如何优雅地实现多头自注意力
https://rockt.github.io/2018/04/30/einsum **