torch.einsum 的计算过程
概论
a = torch.randn(3, 2, 2)
b = torch.randn(3)
c = torch.einsum('...chw,c->...hw', a, b)
上面的 einsum 如何计算的?
简单说,把 b 广播为 a 的形状,然后做矩阵乘法,即逐位相乘运算,注意,不是点积,是逐位的相乘运算。
注:这里符合背景需求,背景是,a 是深度学习的某个张量,b是a的权重,要求 a 的每一个元素都要乘以权重 b ,来得到实际有效的值。
然后,再把矩阵乘积的结果逐位相加后,得到最后结果,同时也去掉了维度c。
运算过程
具体运算细节如下:
为了详细解释 c = torch.einsum('...chw,c->...hw', a, b)
的计算过程,我们可以逐步分析每个部分的运算,并通过一个具体的例子说明结果的产生过程。
1. 张量 a
和 b
的形状与内容
a
是一个形状为(3, 2, 2)
的张量,假设其值为:a = torch.tensor([[[0.1, 0.2], [0.3, 0.4]], [[0.5, 0.6], [0.7, 0.8]], [[0.9, 1.0], [1.1, 1.2]]])
b
是一个形状为(3,)
的张量,假设其值为:b = torch.tensor([2.0, 3.0, 4.0])
2. einsum
表达式 '...chw,c->...hw'
解析
-
...chw
:...
匹配任意数量的前导维度,在本例中没有前导维度。c
对应的是第一个维度(形状为3)。h
对应第二个维度(形状为2)。w
对应第三个维度(形状为2)。
-
c->...hw
:c
对应b
中的元素,作为缩并维度,它会在计算中被消除(通过乘法与求和操作)。...hw
表示最终保留高度(h)和宽度(w)维度。
3. 具体计算过程
步骤 1:广播 b
以匹配 a
的形状
张量 b
(形状 (3,)
)会被广播到形状 (3, 2, 2)
,与张量 a
对应的维度对齐:
b_broadcasted = torch.tensor([[[2.0, 2.0],
[2.0, 2.0]],
[[3.0, 3.0],
[3.0, 3.0]],
[[4.0, 4.0],
[4.0, 4.0]]])
步骤 2:逐元素相乘
a
和广播后的 b
逐元素相乘,结果为:
result = a * b_broadcasted
result = torch.tensor([[[0.2, 0.4],
[0.6, 0.8]],
[[1.5, 1.8],
[2.1, 2.4]],
[[3.6, 4.0],
[4.4, 4.8]]])
步骤 3:沿着 c
维度求和
对上面的结果,沿着 c
维度(也就是第一个维度)求和:
- 对
(0.2, 1.5, 3.6)
这三个数在c
维度上相加,得到5.3
。 - 对
(0.4, 1.8, 4.0)
这三个数在c
维度上相加,得到6.2
。 - 以此类推,求得最终结果
c
:
c = torch.tensor([[5.3, 6.2],
[7.1, 8.0]])
4. 总结结果
最终的结果 c
是一个形状为 (2, 2)
的张量。每个元素是通过 a
中与 b
对应的元素逐元素相乘后,再沿 c
维度(通道维度)求和得到的。
因此,使用 torch.einsum('...chw,c->...hw', a, b)
后,得到的张量 c
对应的是去掉 c
维度后的加权和,仅保留高度 h
和宽度 w
维度的结果。
广播运算
问:b 广播为 b_broadcasted ,代码如何写
答
在 PyTorch 中,可以使用 unsqueeze
和 expand
方法来将张量 b
广播为与 a
形状匹配的 b_broadcasted
。具体代码如下:
import torch
# 原始张量 a 和 b
a = torch.randn(3, 2, 2)
b = torch.randn(3)
# 扩展 b 的维度,使其能够广播到与 a 形状匹配
b_broadcasted = b.unsqueeze(-1).unsqueeze(-1).expand_as(a)
print(b_broadcasted)
步骤解释:
unsqueeze(-1)
:将b
的形状从(3,)
变为(3, 1)
,即在最后一维添加一个大小为1
的新维度。unsqueeze(-1)
:再一次将形状从(3, 1)
变为(3, 1, 1)
,在最后再添加一个大小为1
的维度。expand_as(a)
:将b
广播到与a
相同的形状(3, 2, 2)
。
结果示例:
假设 b
的值为 [2.0, 3.0, 4.0]
,那么 b_broadcasted
将变为:
tensor([[[2.0, 2.0],
[2.0, 2.0]],
[[3.0, 3.0],
[3.0, 3.0]],
[[4.0, 4.0],
[4.0, 4.0]]])
这就是张量 b
在计算过程中如何被广播为 b_broadcasted
的代码实现。