torch.cat()和torch.stack()

torch.cat() 和 torch.stack()略有不同
torch.cat(tensors,dim=0,out=None)→ Tensor
torch.cat()对tensors沿指定维度拼接,但返回的Tensor的维数不会变,可理解为续接;
torch.stack(tensors,dim=0,out=None)→ Tensor
torch.stack()同样是对tensors沿指定维度拼接,但返回的Tensor会多一维,可理解为叠加;
————————————————
版权声明:本文为CSDN博主「进阶媛小吴」的原创文章,遵循CC 4.0 BY-SA版权协议,转载请附上原文出处链接及本声明。
原文链接:https://blog.csdn.net/wuli_xin/article/details/118972316

torch.cat((a,b),dim=1)和torch.cat((a,b)axis=1)一样。

同理:torch.stack((a,b),dim=1)和torch.stack((a,b)axis=1)一样。

zz=torch.rand(100)#默认zz是列向量。而非行向量。

上述3行的情况,自己已经实际实验过。

 

 

 

 

 

结果为:

 

 上述行数相同d,c,在第一维度也即列上拼接时,能拼接成100行六列的tensor.

import torch
a=torch.rand(100)
b=torch.rand(100)
c=torch.rand((100,2))
d=torch.rand((100,2))
e=torch.rand((100,2))

ab,ab1,cd,cd1,cd2,cd3=torch.stack((a,b)),torch.stack((a,b),dim=0),torch.stack((c,d),dim=0),torch.stack((c,d),dim=1),torch.stack((c,d,e),axis=1),torch.stack((c,d,e),dim=-1)
'a.shape',a.shape,"b.shape",b.shape,"ab.shape",ab.shape,"ab1.shape",ab1.shape,'cd.shape',cd.shape,'cd1.shape',cd1.shape,'cd2.shape',cd2.shape,cd3.shape

  此处需注意的是:torch.stack((c,d,e),dim=-1)和torch.stack((c,d,e),dim=2)结果是一样的;

('a.shape',
 torch.Size([100]),
 'b.shape',
 torch.Size([100]),
 'ab.shape',
 torch.Size([2, 100]),
 'ab1.shape',
 torch.Size([2, 100]),
 'cd.shape',
 torch.Size([2, 100, 2]),
 'cd1.shape',
 torch.Size([100, 2, 2]),
 'cd2.shape',
 torch.Size([100, 3, 2]),
 torch.Size([100, 2, 3]))

   

 

import torch
# t1=torch.tensor([1,1,1])
# t2=torch.tensor([2,2,2])
# t3=torch.tensor([3,3,3])

f1=torch.tensor([[1,2,3],[4,5,6]])
f2=torch.tensor([[7,8,9],[10,11,12]])
c=torch.tensor([[13,14,15],[16,17,18]])

# a=torch.cat((f1,f2,c),dim=0)
# b=torch.cat((f1,f2,c),1)
# print(a.shape,b.shape,sep='\n')
# d=torch.rand(100,4)
# e=torch.cat((d,c),1)
# print(e.shape)
g=torch.stack((f1,f2,c),0)
g1=torch.stack((f1,f2,c),1)
print(f1.shape,f2.shape,c.shape)
print('g: ',g.shape,g,sep='\n')
print('g1: ',g1.shape,g1,sep='\n')

输出结果为:

torch.Size([2, 3]) torch.Size([2, 3]) torch.Size([2, 3])
g: 
torch.Size([3, 2, 3])#本来3个,就3个;本来2行3列就两行三列;只不过把他们放到一起,变成了3维的,多了一个维度;个人理解,可能有误。
tensor([[[ 1,  2,  3],
         [ 4,  5,  6]],

        [[ 7,  8,  9],
         [10, 11, 12]],

        [[13, 14, 15],
         [16, 17, 18]]])
g1: 
torch.Size([2, 3, 3])#把本来的三个中,每个的第一列拼在一块;第二列拼在一块;再把拼过后的第一列和第二列分别作为一个二维矩阵; 个人理解,可能有误。

tensor([[[ 1,  2,  3],
         [ 7,  8,  9],
         [13, 14, 15]],

        [[ 4,  5,  6],
         [10, 11, 12],
         [16, 17, 18]]])

 

import torch
# t1=torch.tensor([1,1,1])
# t2=torch.tensor([2,2,2])
# t3=torch.tensor([3,3,3])

f1=torch.tensor([[1,2,3],[4,5,6]])
f2=torch.tensor([[7,8,9],[10,11,12]])
c=torch.tensor([[13,14,15],[16,17,18]])

# a=torch.cat((f1,f2,c),dim=0)
# b=torch.cat((f1,f2,c),1)
# print(a.shape,b.shape,sep='\n')
# d=torch.rand(100,4)
# e=torch.cat((d,c),1)
# print(e.shape)
g=torch.stack((f1,f2,c),0)
g1=torch.stack((f1,f2,c),2)
print(f1.shape,f2.shape,c.shape)
print('g: ',g.shape,g,sep='\n')
print('g1: ',g1.shape,g1,sep='\n')

输出结果:
torch.Size([2, 3]) torch.Size([2, 3]) torch.Size([2, 3])
g: 
torch.Size([3, 2, 3])
tensor([[[ 1,  2,  3],
         [ 4,  5,  6]],
        [[ 7,  8,  9],
         [10, 11, 12]],
        [[13, 14, 15],
         [16, 17, 18]]])
g1: 
torch.Size([2, 3, 3])
tensor([[[ 1,  7, 13],
         [ 2,  8, 14],
         [ 3,  9, 15]],
        [[ 4, 10, 16],
         [ 5, 11, 17],
         [ 6, 12, 18]]])

  

import torch
# t1=torch.tensor([1,1,1])
# t2=torch.tensor([2,2,2])
# t3=torch.tensor([3,3,3])

f1=torch.tensor([[1,2,3],[4,5,6]])
f2=torch.tensor([[7,8,9],[10,11,12]])
c=torch.tensor([[13,14,15],[16,17,18]])

# a=torch.cat((f1,f2,c),dim=0)
# b=torch.cat((f1,f2,c),1)
# print(a.shape,b.shape,sep='\n')
# d=torch.rand(100,4)
# e=torch.cat((d,c),1)
# print(e.shape)
g=torch.stack((f1,f2,c),0)
g1=torch.stack((f1,f2,c),3)#此处dim=3,或比3大的任何正数,都是如下报错结果。
print(f1.shape,f2.shape,c.shape)
print('g: ',g.shape,g,sep='\n')
print('g1: ',g1.shape,g1,sep='\n')

输出结果:
Traceback (most recent call last):
  File "<input>", line 17, in <module>
IndexError: Dimension out of range (expected to be in range of [-3, 2], but got 3)

 

此外: 

如果,torch.stack()的维度dim输入的是-1,-2,-3,也都可以正确输出结果。但是如果输入比-3小的任何数则会报错;具体如下:

import torch
# t1=torch.tensor([1,1,1])
# t2=torch.tensor([2,2,2])
# t3=torch.tensor([3,3,3])

f1=torch.tensor([[1,2,3],[4,5,6]])
f2=torch.tensor([[7,8,9],[10,11,12]])
c=torch.tensor([[13,14,15],[16,17,18]])

# a=torch.cat((f1,f2,c),dim=0)
# b=torch.cat((f1,f2,c),1)
# print(a.shape,b.shape,sep='\n')
# d=torch.rand(100,4)
# e=torch.cat((d,c),1)
# print(e.shape)
g=torch.stack((f1,f2,c),0)
g1=torch.stack((f1,f2,c),-1) #此时,维度是-1
print(f1.shape,f2.shape,c.shape)
print('g: ',g.shape,g,sep='\n')
print('g1: ',g1.shape,g1,sep='\n')

输出结果为:
torch.Size([2, 3]) torch.Size([2, 3]) torch.Size([2, 3])
g: 
torch.Size([3, 2, 3])
tensor([[[ 1,  2,  3],
         [ 4,  5,  6]],

        [[ 7,  8,  9],
         [10, 11, 12]],

        [[13, 14, 15],
         [16, 17, 18]]])
g1: 
torch.Size([2, 3, 3])
tensor([[[ 1,  7, 13],
         [ 2,  8, 14],
         [ 3,  9, 15]],

        [[ 4, 10, 16],
         [ 5, 11, 17],
         [ 6, 12, 18]]])

torch.stack()的维度dim输入的是--2;

输出结果为:
torch.Size([2, 3]) torch.Size([2, 3]) torch.Size([2, 3])
g: 
torch.Size([3, 2, 3])
tensor([[[ 1,  2,  3],
         [ 4,  5,  6]],

        [[ 7,  8,  9],
         [10, 11, 12]],

        [[13, 14, 15],
         [16, 17, 18]]])
g1: 
torch.Size([2, 3, 3])
tensor([[[ 1,  2,  3],
         [ 7,  8,  9],
         [13, 14, 15]],

        [[ 4,  5,  6],
         [10, 11, 12],
         [16, 17, 18]]])

torch.stack()的维度dim输入的是-3;

torch.Size([2, 3]) torch.Size([2, 3]) torch.Size([2, 3])
g: 
torch.Size([3, 2, 3])
tensor([[[ 1,  2,  3],
         [ 4,  5,  6]],

        [[ 7,  8,  9],
         [10, 11, 12]],

        [[13, 14, 15],
         [16, 17, 18]]])
g1: 
torch.Size([3, 2, 3])
tensor([[[ 1,  2,  3],
         [ 4,  5,  6]],

        [[ 7,  8,  9],
         [10, 11, 12]],

        [[13, 14, 15],
         [16, 17, 18]]])

 

import torch
# t1=torch.tensor([1,1,1])
# t2=torch.tensor([2,2,2])
# t3=torch.tensor([3,3,3])

f1=torch.tensor([[1,2,3],[4,5,6]])
f2=torch.tensor([[7,8,9],[10,11,12]])
c=torch.tensor([[13,14,15],[16,17,18]])

# a=torch.cat((f1,f2,c),dim=0)
# b=torch.cat((f1,f2,c),1)
# print(a.shape,b.shape,sep='\n')
# d=torch.rand(100,4)
# e=torch.cat((d,c),1)
# print(e.shape)
g=torch.stack((f1,f2,c),0)
g1=torch.stack((f1,f2,c),-4)#此处是dim=-4,小于-4的任何负数,输出类似的结果。
print(f1.shape,f2.shape,c.shape)
print('g: ',g.shape,g,sep='\n')
print('g1: ',g1.shape,g1,sep='\n')

输出结果为:
Traceback (most recent call last):
  File "<input>", line 17, in <module>
IndexError: Dimension out of range (expected to be in range of [-3, 2], but got -4)

 

import torch
a=torch.rand(100)
b=torch.rand(100)

c=torch.stack((a,b))
d=torch.stack((a,b),dim=0)
e=torch.stack((a,b),dim=1)

f=torch.cat((a,b))
f1=torch.cat((a,b),dim=0)
# f2=torch.cat((a,b),dim=1)#错误提示:Dimension out of range (expected to be in range of [-1, 0], but got 1)
c.size,c.size(),c.shape,d.shape,e.shape,f.shape,f.size(),f1.shape
#从输出结果可看出,torch.rand(100)生成的是100行1列的数据,也即是一个列向量;concat沿着已有的维度拼接,stack在新创建的维度上拼接;
#输出:
(<function Tensor.size>,
 torch.Size([2, 100]),
 torch.Size([2, 100]),
 torch.Size([2, 100]),
 torch.Size([100, 2]),
 torch.Size([200]),
 torch.Size([200]),
 torch.Size([200]))

  

 

posted on 2021-08-20 10:45  lmqljt  阅读(508)  评论(0编辑  收藏  举报

导航