The difference between torch.stack and torch.cat
Conclusion
torch.stack works like stacking in "parallel": it inserts a new dimension into the result, so every input tensor must have exactly the same shape.
torch.cat works like joining in "series": it grows the size of an existing dimension, so the inputs must match in every dimension except the one being concatenated.
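Another way to see the relationship: torch.stack(tensors, dim=d) gives the same result as unsqueezing each input at dimension d and then calling torch.cat. A minimal sketch (the tensor names a and b are made up for illustration):

import torch

a = torch.rand(2, 3)
b = torch.rand(2, 3)

stacked = torch.stack([a, b], dim=0)                          # shape (2, 2, 3): a new axis is added
via_cat = torch.cat([a.unsqueeze(0), b.unsqueeze(0)], dim=0)  # same tensor, built with cat
print(torch.equal(stacked, via_cat))   # True

joined = torch.cat([a, b], dim=0)      # shape (4, 3): an existing axis grows, no new axis
print(joined.size())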
1. torch.stack
import torch

l = []
for i in range(3):
    x = torch.rand(2, 3)   # every tensor in the list has the same shape (2, 3)
    l.append(x)
print(l)

x = torch.stack(l, dim=0)  # new axis inserted at position 0 -> (3, 2, 3)
print(x.size())
z = torch.stack(l, dim=1)  # new axis inserted at position 1 -> (2, 3, 3)
print(z.size())
output:
[tensor([[0.3615, 0.9595, 0.5895],
[0.8202, 0.6924, 0.4683]]), tensor([[0.0988, 0.3804, 0.5348],
[0.0712, 0.4715, 0.1307]]), tensor([[0.1635, 0.4716, 0.1728],
[0.8023, 0.9664, 0.4934]])]
torch.Size([3, 2, 3])
torch.Size([2, 3, 3])
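The dim argument only chooses where the new axis goes; indexing that axis gives back the original tensors. A quick self-contained check (stacked0 and stacked1 are illustrative names, not from the snippet above):

import torch

l = [torch.rand(2, 3) for _ in range(3)]
stacked0 = torch.stack(l, dim=0)   # shape (3, 2, 3)
stacked1 = torch.stack(l, dim=1)   # shape (2, 3, 3)

for i in range(3):
    assert torch.equal(stacked0[i], l[i])      # input i sits at index i of axis 0
    assert torch.equal(stacked1[:, i], l[i])   # input i sits at index i of axis 1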
2. torch.cat
x = torch.randn(2, 3)
print(x)
y = torch.cat((x, x, x), 0)  # dim 0 grows: (2, 3) -> (6, 3)
print(y)
z = torch.cat((x, x, x), 1)  # dim 1 grows: (2, 3) -> (2, 9)
print(z)
output:
tensor([[ 0.6580, -1.0969, -0.4614],
        [-0.1034, -0.5790,  0.1497]])
tensor([[ 0.6580, -1.0969, -0.4614],
        [-0.1034, -0.5790,  0.1497],
        [ 0.6580, -1.0969, -0.4614],
        [-0.1034, -0.5790,  0.1497],
        [ 0.6580, -1.0969, -0.4614],
        [-0.1034, -0.5790,  0.1497]])
tensor([[ 0.6580, -1.0969, -0.4614,  0.6580, -1.0969, -0.4614,  0.6580, -1.0969, -0.4614],
        [-0.1034, -0.5790,  0.1497, -0.1034, -0.5790,  0.1497, -0.1034, -0.5790,  0.1497]])
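For contrast, torch.cat only requires the inputs to agree in the dimensions that are not being concatenated; a mismatch anywhere else raises a RuntimeError. A small sketch with made-up tensors a and b:

import torch

a = torch.rand(2, 3)
b = torch.rand(2, 5)

print(torch.cat((a, b), dim=1).size())   # torch.Size([2, 8]): dim 0 matches, dim 1 grows

try:
    torch.cat((a, b), dim=0)             # dim 1 differs (3 vs. 5), so this fails
except RuntimeError as err:
    print("cat along dim 0 failed:", err)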