torch.einsum 的用法实例

torch 处理 tensor 张量的广播,使用 einsum 函数,摘录一段使用代码,并分析用法

# In[6]:
img_gray_weighted_fancy = torch.einsum('...chw,c->...hw', img_t, weights)
batch_gray_weighted_fancy = torch.einsum('...chw,c->...hw', batch_t, weights)
batch_gray_weighted_fancy.shape

# Out[6]:
torch.Size([2, 5, 5])

这段代码利用 einsum 函数来进行张量运算,并将每个通道的图像加权转换为灰度图像。einsum 的使用使得代码简洁且易于理解。:

代码解读

  1. img_gray_weighted_fancy = torch.einsum('...chw,c->...hw', img_t, weights)

    • img_t 是一个图像张量,其形状假设为 (3, 5, 5),表示3个通道的5x5图像(如RGB图像)。
    • weights 是一个权重张量,其形状为 (3,),对应于每个通道的权重。
    • '...chw' 表示任意形状的张量(...),其中 c 是通道维度,h 是高度维度,w 是宽度维度。因为没有前导维度,img_t 的形状具体为 (3, 5, 5)
    • 'c->...hw' 表示将通道维度 cweights 的权重值相乘,并将其求和后,留下高度和宽度维度(h, w)。最终结果是一个没有通道维度的二维张量,其形状为 (5, 5),即灰度图像。

    总结img_gray_weighted_fancy 是一个5x5的灰度图像,原来的RGB图像被 weights 加权后生成。

  2. batch_gray_weighted_fancy = torch.einsum('...chw,c->...hw', batch_t, weights)

    • batch_t 是一个批量图像张量,假设形状为 (2, 3, 5, 5),表示2个样本、每个样本有3个通道(如RGB),每个通道是5x5的图像。
    • 依然使用 weights 作为权重张量,形状为 (3,),对应每个通道的权重。
    • '...chw' 这里的 ... 表示前导的维度,即批量维度,这个索引表达式允许处理不同大小的张量。
    • 'c->...hw' 指示要将通道维度 c 加权求和,结果保留批量、高度和宽度维度。对于批量图像,每个样本都会生成一个灰度图像。

    总结batch_gray_weighted_fancy 的形状为 (2, 5, 5),表示2个样本对应的5x5灰度图像。

  3. batch_gray_weighted_fancy.shape

    • 打印 batch_gray_weighted_fancy 的形状,得到 torch.Size([2, 5, 5])

    总结:最终输出的形状表明,处理后的批量图像变为2个5x5的灰度图像,每个样本的RGB图像都通过通道加权转换为了灰度图像。

总览

  • einsum 表达式 '...chw,c->...hw' 的含义是:对张量中指定的通道维度 c 进行加权求和,去除通道维度,保留批量、行和列的维度。
  • 这段代码通过 einsum 函数简化了将彩色图像(多通道)转换为灰度图像的过程,避免了显式的广播和求和操作,极大地提升了代码的简洁性和可读性。

提问

上述代码"...chw,c->...hw"中,为什么两次使用 ... ,若变为"chw,c->...hw"可以吗?

回答:

einsum 表达式中,... 是一个通配符,表示任意数量的前导维度。这里的 ... 可以匹配0个或多个维度。表达式中的 ...chw,c->...hw 使用 ... 是为了处理任意数量的前导维度,而不仅仅是一个具体的维度。

为什么使用 ... 两次?

  1. '...chw,c->...hw' 中的 ...chw

    • 这里的 ... 可以匹配 batch_t 的前导维度。例如,如果 batch_t 的形状是 (2, 3, 5, 5),那么 ... 匹配的就是 2
    • chw 表示通道、高度、宽度,其中 c 是通道维度,h 是高度维度,w 是宽度维度。
  2. '...chw,c->...hw' 中的 ...hw

    • 表示结果张量中保留前导维度 ... 和高度、宽度 hw,但去除了通道维度 c

这意味着 ... 通配符可以让这个表达式更通用,无论张量前面有多少个维度,它都可以适用。比如:

  • 对于一个形状为 (3, 5, 5) 的图像张量,...chw 会匹配 (3, 5, 5)...hw 会匹配 (5, 5)
  • 对于一个形状为 (2, 3, 5, 5) 的批量图像张量,...chw 会匹配 (2, 3, 5, 5)...hw 会匹配 (2, 5, 5)

如果改为 'chw,c->...hw' 可以吗?

不可以。原因如下:

  1. 没有通配符 ... 来匹配前导维度

    • 如果你使用 'chw,c->...hw',这里 chw 固定表示只有3个维度:通道、高度、宽度。
    • 这个表达式不能匹配有前导维度的情况(如批量维度),即它只能处理形状精确为 (3, h, w) 的张量。
  2. 无法处理批量张量

    • 如果 batch_t 的形状是 (2, 3, 5, 5)chw 不能正确处理这个形状,因为它缺少了批量维度 (2),结果会导致错误。
    • ... 通配符的使用允许 einsum 表达式能够自动适应任意数量的前导维度(例如批量维度),使得代码更加通用。

总结

使用 ... 通配符两次的原因是为了让这个 einsum 表达式适用于不同的张量形状,无论是否有前导维度(如批量维度)。这使得代码更加灵活和通用。如果你去掉其中的一个 ...,表达式的适用性将变得更窄,无法处理具有额外维度的张量。

问:...chw,c->...hw 变成 ...chw,c->hw,会有什么后果

答:
如果将 einsum 表达式从 ...chw,c->...hw 改为 ...chw,c->hw,会导致运算中前导维度(如批量维度)被丢弃或移除,从而改变了结果张量的形状和语义。下面详细解释其后果。

原始表达式:...chw,c->...hw

  • ...chw:匹配任意数量的前导维度(...),以及通道、高度、宽度。
  • c->...hw:消除了通道维度 c,保留了前导维度 ... 和高度、宽度 hw

对于形状为 (2, 3, 5, 5)batch_t 张量(其中 2 是批量维度),运算后得到形状为 (2, 5, 5) 的张量。前导维度(批量维度)被保留。

修改后的表达式:...chw,c->hw

  • ...chw:同样匹配任意数量的前导维度以及通道、高度、宽度。
  • c->hw:这里的 hw 只保留了高度和宽度维度,完全忽略了前导维度 ...

后果:

  1. 前导维度丢失

    • 前导维度 ...(例如批量维度)被丢弃。
    • 如果输入张量有前导维度(如批量维度 2),这些维度在结果中将不再存在。
  2. 结果形状

    • 假设 batch_t 的形状是 (2, 3, 5, 5)
      • ...chw 中的 ... 匹配 2(批量维度)。
      • 由于结果只保留 hw(高度和宽度),最终输出的形状将是 (5, 5),即没有了批量维度。
    • 这意味着,无论输入张量的前导维度是什么,最终结果都是将每个批次的图像折叠成一个单独的高度-宽度图像,导致批量信息丢失。
  3. 语义混乱

    • 在深度学习处理中,保留批量维度通常是非常重要的。批量维度的丢失意味着无法将结果与原始输入数据一一对应。
    • 如果你期望处理一组图像并保留每张图像的结果,但由于误操作丢失了批量维度,那么在后续步骤中处理这些结果会非常混乱。

结论

einsum 表达式从 ...chw,c->...hw 改为 ...chw,c->hw 会导致丢失前导维度(如批量维度)。如果前导维度被丢弃,输出将不再保留批次的结构信息,这在许多情况下可能是不可取的。通常你需要保留前导维度,除非你的具体应用场景明确不需要它们。

posted @ 2024-08-09 07:36  立体风  阅读(60)  评论(0编辑  收藏  举报