torch.stack和torch.cat区别

结论
torch.stack ,类似并联,结果是在相应的维度会增加一维; 所以要求每个tensor大小相等
torch.cat 类似串联,结果是在相应的维度值会增加,所以要求扩展的维度大小相等

1.torch.stack

l = []
for i in range(0,3):
    x = torch.rand(2,3)
    l.append(x)
print(l)

x = torch.stack(l,dim=0)

print(x.size())

z = torch.stack(l,dim=1)
print(z.size())

output:
[tensor([[0.3615, 0.9595, 0.5895],
[0.8202, 0.6924, 0.4683]]), tensor([[0.0988, 0.3804, 0.5348],
[0.0712, 0.4715, 0.1307]]), tensor([[0.1635, 0.4716, 0.1728],
[0.8023, 0.9664, 0.4934]])]
torch.Size([3, 2, 3])
torch.Size([2, 3, 3])

2.torch.cat

 x = torch.randn(2, 3)
 torch.cat((x, x, x), 0)
 x

tensor([[ 0.6580, -1.0969, -0.4614],
[-0.1034, -0.5790, 0.1497]])
tensor([[ 0.6580, -1.0969, -0.4614],
[-0.1034, -0.5790, 0.1497],
[ 0.6580, -1.0969, -0.4614],
[-0.1034, -0.5790, 0.1497],
[ 0.6580, -1.0969, -0.4614],
[-0.1034, -0.5790, 0.1497]])

 torch.cat((x, x, x), 1)

tensor([[ 0.6580, -1.0969, -0.4614, 0.6580, -1.0969, -0.4614, 0.6580,
-1.0969, -0.4614],
[-0.1034, -0.5790, 0.1497, -0.1034, -0.5790, 0.1497, -0.1034,
-0.5790, 0.1497]])

posted @ 2021-07-27 20:27  哈哈哈喽喽喽  阅读(222)  评论(0编辑  收藏  举报