torch.cat() 与 torch.stack() 的区别
1. torch.cat()
torch.cat(tensors, dim=0)
在给定维度中拼接张量序列。
参数:
tensors
:张量序列。dim
:拼接张量序列的维度。
import torch
a = torch.rand(2, 3)
b = torch.rand(2, 3)
c = torch.cat((a, b))
print(a.size(), b.size(), c.size())
torch.Size([2, 3]) torch.Size([2, 3]) torch.Size([4, 3])
可以看出,\(a、b、c\) 都是二维。
张量序列必须具有相同大小:
d = torch.rand(2, 4)
print(torch.cat((a, d)))
RuntimeError: Sizes of tensors must match except in dimension 0. Expected size 3 but got size 4 for tensor number 1 in the list.
具体拼接:
print(a)
print(torch.cat((a, a, a), dim=0))
print(torch.cat((a, a, a), dim=1))
tensor([[0.2381, 0.7100, 0.8150],
[0.5190, 0.5829, 0.9186]])
tensor([[0.2381, 0.7100, 0.8150],
[0.5190, 0.5829, 0.9186],
[0.2381, 0.7100, 0.8150],
[0.5190, 0.5829, 0.9186],
[0.2381, 0.7100, 0.8150],
[0.5190, 0.5829, 0.9186]])
tensor([[0.2381, 0.7100, 0.8150, 0.2381, 0.7100, 0.8150, 0.2381, 0.7100, 0.8150],
[0.5190, 0.5829, 0.9186, 0.5190, 0.5829, 0.9186, 0.5190, 0.5829, 0.9186]])
2. torch.stack()
torch.stack(tensors, dim=0)
沿新维度拼接张量。
参数:
tensors
:张量序列dim
:要插入的维度。
import torch
a = torch.rand((2, 3))
b = torch.rand((2, 3))
c = torch.stack((a, b))
print(a.size(), b.size(), c.size())
torch.Size([2, 3]) torch.Size([2, 3]) torch.Size([2, 2, 3])
可以看出,\(a、b\) 是二维,\(c\) 是三维。
张量序列必须具有相同大小:
d = torch.rand(2, 4)
print(torch.stack((a, d)))
RuntimeError: stack expects each tensor to be equal size, but got [2, 3] at entry 0 and [2, 4] at entry 1
具体拼接:
x = torch.arange(1, 7).reshape((3, 2))
y = torch.arange(10, 70, 10).reshape((3, 2))
z = torch.arange(100, 700, 100).reshape((3, 2))
print(x)
print(y)
print(z)
tensor([[1, 2],
[3, 4],
[5, 6]])
tensor([[10, 20],
[30, 40],
[50, 60]])
tensor([[100, 200],
[300, 400],
[500, 600]])
m = torch.stack((x,y,z))
print(m)
tensor([[[ 1, 2],
[ 3, 4],
[ 5, 6]],
[[ 10, 20],
[ 30, 40],
[ 50, 60]],
[[100, 200],
[300, 400],
[500, 600]]])
n = torch.stack((x,y,z), 1)
print(n)
tensor([[[ 1, 2],
[ 10, 20],
[100, 200]],
[[ 3, 4],
[ 30, 40],
[300, 400]],
[[ 5, 6],
[ 50, 60],
[500, 600]]])
h = torch.stack((x,y,z), 2)
print(h)
tensor([[[ 1, 10, 100],
[ 2, 20, 200]],
[[ 3, 30, 300],
[ 4, 40, 400]],
[[ 5, 50, 500],
[ 6, 60, 600]]])