torch.stack 堆叠函数帮助理解多维数组
概论
在 PyTorch 中,torch.stack
函数用于在指定的维度上将一组张量堆叠起来。这个操作会在指定维度上创建一个新的维度,并将输入张量在该维度上进行堆叠。假设有两个形状相同的张量 a
和 b
,它们的形状都是 (2, 3, 4)
,那么在不同的 dim
参数下使用 torch.stack
会产生不同的结果。
以下是对这四种情况的解释:
-
c = torch.stack([a, b], dim=0)
- 在
dim=0
的位置上创建一个新的维度。 - 原始张量的形状为
(2, 3, 4)
,堆叠后形状变为(2, 3, 4)
前加上一个新的维度,形状变为(2, 2, 3, 4)
。 - 堆叠后张量
c
的形状为(2, 3, 4)
。 - 可以理解为把
a
和b
堆叠在第一个维度上,结果的第一个维度表示堆叠的张量数目。
- 在
-
d = torch.stack([a, b], dim=1)
- 在
dim=1
的位置上创建一个新的维度。 - 原始张量的形状为
(2, 3, 4)
,堆叠后形状变为(2, 2, 3, 4)
。 - 堆叠后张量
d
的形状为(2, 2, 3, 4)
。 - 可以理解为在第二个维度上插入一个新的维度,使每个原始张量的第一维度内的每个元素都变为包含两个子元素的张量。
- 在
-
e = torch.stack([a, b], dim=2)
- 在
dim=2
的位置上创建一个新的维度。 - 原始张量的形状为
(2, 3, 4)
,堆叠后形状变为(2, 3, 2, 4)
。 - 堆叠后张量
e
的形状为(2, 3, 2, 4)
。 - 这表示在第三个维度上创建新的维度,每个原始张量的前两个维度内的每个元素都变为包含两个子元素的张量。
- 在
-
f = torch.stack([a, b], dim=3)
- 在
dim=3
的位置上创建一个新的维度。 - 原始张量的形状为
(2, 3, 4)
,堆叠后形状变为(2, 3, 4, 2)
。 - 堆叠后张量
f
的形状为(2, 3, 4, 2)
。 - 这表示在第四个维度上创建新的维度,每个原始张量的前三个维度内的每个元素都变为包含两个子元素的张量。
- 在
总结来说,torch.stack
会在指定的 dim
维度上插入一个新的维度,使得原始张量在这个维度上堆叠起来。新的张量的形状将会比原始张量多一个维度,且堆叠方向对应于 dim
所指定的位置。
一言以蔽之
在指定维度 dim
上使用 torch.stack
堆叠时,会在该维度插入一个新的维度,使得原始张量在 dim
之前的所有维度的每个元素都变为包含堆叠张量数目的子元素的张量。
解读
这句话的意思是:
当你在指定的维度 dim
上使用 torch.stack
函数时,它会在张量的这个位置插入一个新的维度。这个新的维度会将原始张量沿着 dim
之前的所有维度中的每个元素扩展,使得这些元素现在在新增的维度上包含多个(等于你堆叠的张量个数)子元素。
具体来说,假设你有多个形状相同的张量,当你在某个维度 dim
上堆叠它们时,堆叠后的新张量在 dim
之前的每一个维度上的每个元素都会新增一个维度,用来存放你堆叠的张量。这些子元素的数量就是你堆叠的张量数目。
例如:
dim=0
:在第一个维度上堆叠,那么结果张量的第一个维度的大小就是堆叠的张量个数,每个子张量会在这个新的维度中排列。dim=1
:在第二个维度上堆叠,那么结果张量的第二个维度的大小就是堆叠的张量个数,原来第一维度的每个元素现在会包含多个子元素。- 以此类推,堆叠的位置决定了新的维度插入在哪里,以及原张量如何被扩展。
再探讨
若一个张量 a 形状是 (2, 3, 4),堆叠代码:torch.stack([a, a], dim=?)
当 dim=0 时,由于是第0维度,前面没有了,故把整个张量看作一个元素堆叠。
当 dim=1 时,前面有维度 2, 在这里插入新维度会把 (3, 4) 看作一个元素,进行堆叠。
当 dim=2 时,前面有维度(2, 3),这时会把(4, )看作一个元素,进行堆叠。
当 dim=3 时, 前面有维度(2, 3, 4),这是会把每一个元素看做一个元素,进行堆叠。
看代码,再读上文。
a = torch.randn(2, 3, 4)
a
tensor([[[ 0.4964, -0.2426, -0.4883, -0.9112],
[ 0.2928, 1.8061, -0.0770, -0.2761],
[-0.1384, 0.5872, 0.1957, 1.4741]],
[[-1.1077, 1.0878, 0.4793, 0.9741],
[ 2.0282, 0.7055, -0.0954, -0.3203],
[-0.7217, -1.1332, 0.0738, -0.8602]]])
b= torch.stack([a, a], dim=0)
b # 把整个a看作一个元素,进行堆叠
tensor([[[[ 0.4964, -0.2426, -0.4883, -0.9112],
[ 0.2928, 1.8061, -0.0770, -0.2761],
[-0.1384, 0.5872, 0.1957, 1.4741]],
[[-1.1077, 1.0878, 0.4793, 0.9741],
[ 2.0282, 0.7055, -0.0954, -0.3203],
[-0.7217, -1.1332, 0.0738, -0.8602]]],
[[[ 0.4964, -0.2426, -0.4883, -0.9112],
[ 0.2928, 1.8061, -0.0770, -0.2761],
[-0.1384, 0.5872, 0.1957, 1.4741]],
[[-1.1077, 1.0878, 0.4793, 0.9741],
[ 2.0282, 0.7055, -0.0954, -0.3203],
[-0.7217, -1.1332, 0.0738, -0.8602]]]])
b= torch.stack([a, a], dim=1)
b # 把a的(3, 4)部分看作一个元素,进行堆叠
tensor([[[[ 0.4964, -0.2426, -0.4883, -0.9112],
[ 0.2928, 1.8061, -0.0770, -0.2761],
[-0.1384, 0.5872, 0.1957, 1.4741]],
[[ 0.4964, -0.2426, -0.4883, -0.9112],
[ 0.2928, 1.8061, -0.0770, -0.2761],
[-0.1384, 0.5872, 0.1957, 1.4741]]],
[[[-1.1077, 1.0878, 0.4793, 0.9741],
[ 2.0282, 0.7055, -0.0954, -0.3203],
[-0.7217, -1.1332, 0.0738, -0.8602]],
[[-1.1077, 1.0878, 0.4793, 0.9741],
[ 2.0282, 0.7055, -0.0954, -0.3203],
[-0.7217, -1.1332, 0.0738, -0.8602]]]])
b= torch.stack([a, a], dim=2)
b #把a的(4,)部分看作一个元素,进行堆叠
tensor([[[[ 0.4964, -0.2426, -0.4883, -0.9112],
[ 0.4964, -0.2426, -0.4883, -0.9112]],
[[ 0.2928, 1.8061, -0.0770, -0.2761],
[ 0.2928, 1.8061, -0.0770, -0.2761]],
[[-0.1384, 0.5872, 0.1957, 1.4741],
[-0.1384, 0.5872, 0.1957, 1.4741]]],
[[[-1.1077, 1.0878, 0.4793, 0.9741],
[-1.1077, 1.0878, 0.4793, 0.9741]],
[[ 2.0282, 0.7055, -0.0954, -0.3203],
[ 2.0282, 0.7055, -0.0954, -0.3203]],
[[-0.7217, -1.1332, 0.0738, -0.8602],
[-0.7217, -1.1332, 0.0738, -0.8602]]]])
b= torch.stack([a, a], dim=3)
b # 把a的每个元素看作一个元素进行堆叠
tensor([[[[ 0.4964, 0.4964],
[-0.2426, -0.2426],
[-0.4883, -0.4883],
[-0.9112, -0.9112]],
[[ 0.2928, 0.2928],
[ 1.8061, 1.8061],
[-0.0770, -0.0770],
[-0.2761, -0.2761]],
[[-0.1384, -0.1384],
[ 0.5872, 0.5872],
[ 0.1957, 0.1957],
[ 1.4741, 1.4741]]],
[[[-1.1077, -1.1077],
[ 1.0878, 1.0878],
[ 0.4793, 0.4793],
[ 0.9741, 0.9741]],
[[ 2.0282, 2.0282],
[ 0.7055, 0.7055],
[-0.0954, -0.0954],
[-0.3203, -0.3203]],
[[-0.7217, -0.7217],
[-1.1332, -1.1332],
[ 0.0738, 0.0738],
[-0.8602, -0.8602]]]])
再总结
torcha.stack 把dim=? 指定插入维度后,把原有维度以插入维度为起点,看作一个整体,做为一个堆叠元素,进行堆叠。
例如,有 a 形状为 (2, 5, 8, 3) ,当dim=1时,对a进行切片以(5,8,3)为一个元素,进行堆叠。
故 torch.stack([a, a], dim=1)
的结果形状为:(2,2,5,8,3)