einsum 函数
einsum
是 Einstein summation 的缩写,即 爱因斯坦求和约定。einsum
函数源自 NumPy,后来在 PyTorch 等其他科学计算库中也得到了实现。它是一种强大而灵活的函数,可以用来处理各种张量运算,如矩阵乘法、转置、批量点积、内积、外积等。
爱因斯坦求和约定 (Einstein Summation Convention)
爱因斯坦求和约定是一种简洁的记号,用来表示张量运算中的求和操作。其核心思想是:如果在表达式中某个索引变量出现在多项式的两侧(即两个张量中),则默认对该索引进行求和,而不需要明确地写出求和符号。
einsum 的作用
einsum
函数允许用户通过指定输入张量的索引和输出张量的索引来定义各种张量运算。它不仅使表达式更加简洁和直观,还避免了对张量进行显式的维度操作(如转置、reshape 等),从而提高代码的可读性和效率。
如何使用 einsum
假设我们有两个矩阵 ( A ) 和 ( B ),它们的形状分别是 (i, j)
和 (j, k)
,常规的矩阵乘法可以表示为:
[ C_{ik} = \sum_j A_{ij} \times B_{jk} ]
在 einsum
中,这可以写作:
C = torch.einsum('ij,jk->ik', A, B)
这里的 'ij,jk->ik'
就是索引迷你语言,解释如下:
'ij'
是第一个矩阵A
的索引,表示A
的第一个维度用i
表示,第二个维度用j
表示。'jk'
是第二个矩阵B
的索引,表示B
的第一个维度用j
表示,第二个维度用k
表示。->ik
指定输出矩阵C
的索引结构,即结果C
的第一个维度是i
,第二个维度是k
。
einsum 的优点
- 灵活性:可以表示从简单的标量积到复杂的张量操作的一切运算。
- 简洁性:消除显式的循环和冗余操作,使代码更加简洁易读。
- 效率:通过在底层进行优化,
einsum
能够在不影响性能的前提下执行复杂的张量操作。
例子
-
矩阵乘法:
C = torch.einsum('ij,jk->ik', A, B)
-
内积(点积):
dot_product = torch.einsum('i,i->', x, y)
-
外积:
outer_product = torch.einsum('i,j->ij', x, y)
-
批量矩阵乘法:
batch_matrix_mult = torch.einsum('bij,bjk->bik', batch_A, batch_B)
在复杂的张量操作中,einsum
可以简化表达并提升代码的可读性和效率。