torch.einsum 的用法实例
torch 处理 tensor 张量的广播,使用 einsum 函数,摘录一段使用代码,并分析用法
# In[6]:
img_gray_weighted_fancy = torch.einsum('...chw,c->...hw', img_t, weights)
batch_gray_weighted_fancy = torch.einsum('...chw,c->...hw', batch_t, weights)
batch_gray_weighted_fancy.shape
# Out[6]:
torch.Size([2, 5, 5])
这段代码利用 einsum
函数来进行张量运算,并将每个通道的图像加权转换为灰度图像。einsum
的使用使得代码简洁且易于理解。:
代码解读
-
img_gray_weighted_fancy = torch.einsum('...chw,c->...hw', img_t, weights)
img_t
是一个图像张量,其形状假设为(3, 5, 5)
,表示3个通道的5x5图像(如RGB图像)。weights
是一个权重张量,其形状为(3,)
,对应于每个通道的权重。'...chw'
表示任意形状的张量(...
),其中c
是通道维度,h
是高度维度,w
是宽度维度。因为没有前导维度,img_t
的形状具体为(3, 5, 5)
。'c->...hw'
表示将通道维度c
与weights
的权重值相乘,并将其求和后,留下高度和宽度维度(h, w
)。最终结果是一个没有通道维度的二维张量,其形状为(5, 5)
,即灰度图像。
总结:
img_gray_weighted_fancy
是一个5x5的灰度图像,原来的RGB图像被weights
加权后生成。 -
batch_gray_weighted_fancy = torch.einsum('...chw,c->...hw', batch_t, weights)
batch_t
是一个批量图像张量,假设形状为(2, 3, 5, 5)
,表示2个样本、每个样本有3个通道(如RGB),每个通道是5x5的图像。- 依然使用
weights
作为权重张量,形状为(3,)
,对应每个通道的权重。 '...chw'
这里的...
表示前导的维度,即批量维度,这个索引表达式允许处理不同大小的张量。'c->...hw'
指示要将通道维度c
加权求和,结果保留批量、高度和宽度维度。对于批量图像,每个样本都会生成一个灰度图像。
总结:
batch_gray_weighted_fancy
的形状为(2, 5, 5)
,表示2个样本对应的5x5灰度图像。 -
batch_gray_weighted_fancy.shape
- 打印
batch_gray_weighted_fancy
的形状,得到torch.Size([2, 5, 5])
。
总结:最终输出的形状表明,处理后的批量图像变为2个5x5的灰度图像,每个样本的RGB图像都通过通道加权转换为了灰度图像。
- 打印
总览
einsum
表达式'...chw,c->...hw'
的含义是:对张量中指定的通道维度c
进行加权求和,去除通道维度,保留批量、行和列的维度。- 这段代码通过
einsum
函数简化了将彩色图像(多通道)转换为灰度图像的过程,避免了显式的广播和求和操作,极大地提升了代码的简洁性和可读性。
提问
上述代码"...chw,c->...hw"中,为什么两次使用 ... ,若变为"chw,c->...hw"可以吗?
回答:
在 einsum
表达式中,...
是一个通配符,表示任意数量的前导维度。这里的 ...
可以匹配0个或多个维度。表达式中的 ...chw,c->...hw
使用 ...
是为了处理任意数量的前导维度,而不仅仅是一个具体的维度。
为什么使用 ...
两次?
-
'...chw,c->...hw'
中的...chw
:- 这里的
...
可以匹配batch_t
的前导维度。例如,如果batch_t
的形状是(2, 3, 5, 5)
,那么...
匹配的就是2
。 chw
表示通道、高度、宽度,其中c
是通道维度,h
是高度维度,w
是宽度维度。
- 这里的
-
'...chw,c->...hw'
中的...hw
:- 表示结果张量中保留前导维度
...
和高度、宽度hw
,但去除了通道维度c
。
- 表示结果张量中保留前导维度
这意味着 ...
通配符可以让这个表达式更通用,无论张量前面有多少个维度,它都可以适用。比如:
- 对于一个形状为
(3, 5, 5)
的图像张量,...chw
会匹配(3, 5, 5)
,...hw
会匹配(5, 5)
。 - 对于一个形状为
(2, 3, 5, 5)
的批量图像张量,...chw
会匹配(2, 3, 5, 5)
,...hw
会匹配(2, 5, 5)
。
如果改为 'chw,c->...hw'
可以吗?
不可以。原因如下:
-
没有通配符
...
来匹配前导维度:- 如果你使用
'chw,c->...hw'
,这里chw
固定表示只有3个维度:通道、高度、宽度。 - 这个表达式不能匹配有前导维度的情况(如批量维度),即它只能处理形状精确为
(3, h, w)
的张量。
- 如果你使用
-
无法处理批量张量:
- 如果
batch_t
的形状是(2, 3, 5, 5)
,chw
不能正确处理这个形状,因为它缺少了批量维度(2)
,结果会导致错误。 ...
通配符的使用允许einsum
表达式能够自动适应任意数量的前导维度(例如批量维度),使得代码更加通用。
- 如果
总结
使用 ...
通配符两次的原因是为了让这个 einsum
表达式适用于不同的张量形状,无论是否有前导维度(如批量维度)。这使得代码更加灵活和通用。如果你去掉其中的一个 ...
,表达式的适用性将变得更窄,无法处理具有额外维度的张量。
问:...chw,c->...hw 变成 ...chw,c->hw,会有什么后果
答:
如果将 einsum
表达式从 ...chw,c->...hw
改为 ...chw,c->hw
,会导致运算中前导维度(如批量维度)被丢弃或移除,从而改变了结果张量的形状和语义。下面详细解释其后果。
原始表达式:...chw,c->...hw
...chw
:匹配任意数量的前导维度(...
),以及通道、高度、宽度。c->...hw
:消除了通道维度c
,保留了前导维度...
和高度、宽度hw
。
对于形状为 (2, 3, 5, 5)
的 batch_t
张量(其中 2
是批量维度),运算后得到形状为 (2, 5, 5)
的张量。前导维度(批量维度)被保留。
修改后的表达式:...chw,c->hw
...chw
:同样匹配任意数量的前导维度以及通道、高度、宽度。c->hw
:这里的hw
只保留了高度和宽度维度,完全忽略了前导维度...
。
后果:
-
前导维度丢失:
- 前导维度
...
(例如批量维度)被丢弃。 - 如果输入张量有前导维度(如批量维度
2
),这些维度在结果中将不再存在。
- 前导维度
-
结果形状:
- 假设
batch_t
的形状是(2, 3, 5, 5)
:...chw
中的...
匹配2
(批量维度)。- 由于结果只保留
hw
(高度和宽度),最终输出的形状将是(5, 5)
,即没有了批量维度。
- 这意味着,无论输入张量的前导维度是什么,最终结果都是将每个批次的图像折叠成一个单独的高度-宽度图像,导致批量信息丢失。
- 假设
-
语义混乱:
- 在深度学习处理中,保留批量维度通常是非常重要的。批量维度的丢失意味着无法将结果与原始输入数据一一对应。
- 如果你期望处理一组图像并保留每张图像的结果,但由于误操作丢失了批量维度,那么在后续步骤中处理这些结果会非常混乱。
结论
将 einsum
表达式从 ...chw,c->...hw
改为 ...chw,c->hw
会导致丢失前导维度(如批量维度)。如果前导维度被丢弃,输出将不再保留批次的结构信息,这在许多情况下可能是不可取的。通常你需要保留前导维度,除非你的具体应用场景明确不需要它们。