torch.einsum 的计算过程

概论

a = torch.randn(3, 2, 2)
b = torch.randn(3)
c = torch.einsum('...chw,c->...hw', a, b)

上面的 einsum 如何计算的?
简单说,把 b 广播为 a 的形状,然后做矩阵乘法,即逐位相乘运算,注意,不是点积,是逐位的相乘运算。
注:这里符合背景需求,背景是,a 是深度学习的某个张量,b是a的权重,要求 a 的每一个元素都要乘以权重 b ,来得到实际有效的值。
然后,再把矩阵乘积的结果逐位相加后,得到最后结果,同时也去掉了维度c。

运算过程

具体运算细节如下:

为了详细解释 c = torch.einsum('...chw,c->...hw', a, b) 的计算过程,我们可以逐步分析每个部分的运算,并通过一个具体的例子说明结果的产生过程。

1. 张量 ab 的形状与内容

  • a 是一个形状为 (3, 2, 2) 的张量,假设其值为:
    a = torch.tensor([[[0.1, 0.2],
                       [0.3, 0.4]],
    
                      [[0.5, 0.6],
                       [0.7, 0.8]],
    
                      [[0.9, 1.0],
                       [1.1, 1.2]]])
    
  • b 是一个形状为 (3,) 的张量,假设其值为:
    b = torch.tensor([2.0, 3.0, 4.0])
    

2. einsum 表达式 '...chw,c->...hw' 解析

  • ...chw:

    • ... 匹配任意数量的前导维度,在本例中没有前导维度。
    • c 对应的是第一个维度(形状为3)。
    • h 对应第二个维度(形状为2)。
    • w 对应第三个维度(形状为2)。
  • c->...hw:

    • c 对应 b 中的元素,作为缩并维度,它会在计算中被消除(通过乘法与求和操作)。
    • ...hw 表示最终保留高度(h)和宽度(w)维度。

3. 具体计算过程

步骤 1:广播 b 以匹配 a 的形状

张量 b(形状 (3,))会被广播到形状 (3, 2, 2),与张量 a 对应的维度对齐:

b_broadcasted = torch.tensor([[[2.0, 2.0],
                               [2.0, 2.0]],

                              [[3.0, 3.0],
                               [3.0, 3.0]],

                              [[4.0, 4.0],
                               [4.0, 4.0]]])

步骤 2:逐元素相乘

a 和广播后的 b 逐元素相乘,结果为:

result = a * b_broadcasted
result = torch.tensor([[[0.2, 0.4],
                        [0.6, 0.8]],

                       [[1.5, 1.8],
                        [2.1, 2.4]],

                       [[3.6, 4.0],
                        [4.4, 4.8]]])

步骤 3:沿着 c 维度求和

对上面的结果,沿着 c 维度(也就是第一个维度)求和:

  • (0.2, 1.5, 3.6) 这三个数在 c 维度上相加,得到 5.3
  • (0.4, 1.8, 4.0) 这三个数在 c 维度上相加,得到 6.2
  • 以此类推,求得最终结果 c
c = torch.tensor([[5.3, 6.2],
                  [7.1, 8.0]])

4. 总结结果

最终的结果 c 是一个形状为 (2, 2) 的张量。每个元素是通过 a 中与 b 对应的元素逐元素相乘后,再沿 c 维度(通道维度)求和得到的。

因此,使用 torch.einsum('...chw,c->...hw', a, b) 后,得到的张量 c 对应的是去掉 c 维度后的加权和,仅保留高度 h 和宽度 w 维度的结果。

广播运算

问:b 广播为 b_broadcasted ,代码如何写

在 PyTorch 中,可以使用 unsqueezeexpand 方法来将张量 b 广播为与 a 形状匹配的 b_broadcasted。具体代码如下:

import torch

# 原始张量 a 和 b
a = torch.randn(3, 2, 2)
b = torch.randn(3)

# 扩展 b 的维度,使其能够广播到与 a 形状匹配
b_broadcasted = b.unsqueeze(-1).unsqueeze(-1).expand_as(a)

print(b_broadcasted)

步骤解释:

  1. unsqueeze(-1):将 b 的形状从 (3,) 变为 (3, 1),即在最后一维添加一个大小为 1 的新维度。
  2. unsqueeze(-1):再一次将形状从 (3, 1) 变为 (3, 1, 1),在最后再添加一个大小为 1 的维度。
  3. expand_as(a):将 b 广播到与 a 相同的形状 (3, 2, 2)

结果示例:

假设 b 的值为 [2.0, 3.0, 4.0],那么 b_broadcasted 将变为:

tensor([[[2.0, 2.0],
         [2.0, 2.0]],

        [[3.0, 3.0],
         [3.0, 3.0]],

        [[4.0, 4.0],
         [4.0, 4.0]]])

这就是张量 b 在计算过程中如何被广播为 b_broadcasted 的代码实现。

posted @ 2024-08-09 09:30  立体风  阅读(66)  评论(0编辑  收藏  举报