einsum求矩阵运算

einsum是什么

使用chatGPT解读官方的文档

Einsum允许使用基于Einstein求和约定的简写格式来计算许多常见的多维线性代数数组操作,这些操作可以表示为一个方程式。
具体格式的细节在下面进行描述,但是一般的思想是使用一些下标为输入操作数的每个维度进行标记,并且定义哪些下标是输出的一部分。
然后,通过沿着那些下标不是输出的维度对操作数的元素积进行求和来计算输出结果。
例如,可以使用einsum来计算矩阵乘法,如torch.einsum("ij,jk->ik", A, B)。
在这种情况下,j是求和下标,i和k是输出下标(有关更多详细信息,请参见下面的部分)。

因此描述的也比较清楚,总结而言就是:对每个维度进行标记,输出需要维度的结果,未出现的维度就会作为求和的维度被消除掉。

实现

下面引用了一下官网的例子,后面在多看看,多理解一下,还有点懵。

# trace
torch.einsum('ii', torch.randn(4, 4))

# diagonal
torch.einsum('ii->i', torch.randn(4, 4))

# outer product
x = torch.randn(5)
y = torch.randn(4)
torch.einsum('i,j->ij', x, y)

# batch matrix multiplication
As = torch.randn(3, 2, 5)
Bs = torch.randn(3, 5, 4)
torch.einsum('bij,bjk->bik', As, Bs)



# with sublist format and ellipsis
torch.einsum(As, [..., 0, 1], Bs, [..., 1, 2], [..., 0, 2])



# batch permute
A = torch.randn(2, 3, 4, 5)
torch.einsum('...ij->...ji', A).shape

# equivalent to torch.nn.functional.bilinear
A = torch.randn(3, 5, 4)
l = torch.randn(2, 5)
r = torch.randn(2, 4)
torch.einsum('bn,anm,bm->ba', l, A, r)
posted on 2023-06-07 15:48  蔚蓝色の天空  阅读(43)  评论(0编辑  收藏  举报