pytorch-合并与分割
Merge or split
▪ Cat(合并)
▪ Stack(合并)
▪ Split(拆分)
▪ Chunk(拆分)
合并
cat
这个就是合并两个tensor
比如说有两个班级的成绩单,一个是1-4班的,一个是5-9班的,我们现在需要合并这两份成绩单。
▪ Statistics about scores
▪ [class1-4, students, scores]
▪ [class5-9, students, scores]
torch.cat([a,b],dim)
这个就是合并a,b两个tensor再第dim个维度上,需要注意的是除了dim这个维度,剩下的维度都要shape相等
a=torch.rand(4,32,8)
b=torch.rand(5,32,8)
torch.cat([a,b],dim=0).shape
# torch.Size([9, 32, 8])
a1=torch.rand(4,3,32,32)
a2=torch.rand(5,3,32,32)
torch.cat([a1,a2],dim=0).shape
# torch.Size([9, 3, 32, 32])
如果
a1=torch.rand(4,3,32,32)
a2=torch.rand(4,1,32,32)
torch.cat([a1,a2],dim=0).shape
由于除了dim=0以外dim=1的维度不相同,所以不行
但是假如说
torch.cat([a1,a2],dim=1).shape
# torch.Size([4, 4, 32, 32])
a1=torch.rand(4,3,16,32)
a2=torch.rand(4,3,16,32)
torch.cat([a1,a2],dim=2).shape
# torch.Size([4, 3, 32, 32])
# 这个时候就是两张图片,上下拼接上
stack
这个也是合并
stack([a,b],dim)
不过这个合并和上一个不一样,这个合并会创造一个新的维度,比如[32,8]和[32,8],在dim=0维度进行合并的话是[2,32,8]。然后res[0,:,:]是第一个,res[1,:,:]是第二个。
a1=torch.rand(4,3,16,32)
a2=torch.rand(4,3,16,32)
torch.cat([a1,a2],dim=2).shape
# torch.Size([4, 3, 32, 32])
torch.stack([a1,a2],dim=2).shape
# torch.Size([4, 3, 2, 16, 32])
然后我们下面有一个具体的场景,就是一共有俩个班级,每个班级一共有32个学生,每个学生有8门课,进行合并。这个时候我们就不能利用cat了,因为cat合并的结果为[64,8]。而stack合并的结果为[2,32,8],这样更符合要求
a=torch.rand(32,8)
b=torch.rand(32,8)
torch.stack([a,b],dim=0).shape
# torch.Size([2, 32, 8])
但是需要注意的是除了要合并的那个维度,其余的维度都要相等。
拆分
split
这个是通过长度进行拆分的。
.split([,],dim)
[,]里面的参数是最终的得到的tensor的dim上的维度,比如说c=[3,32,8],
aa,bb=c.split([2,1],dim=0),
那么aa.shape=[2,32,8],bb.shape=[1,32,8]。返回的该维度上的值的和要等于目标函数在该维度上的值
或者说也可以直接给顶一个len,这个就是返回的最终函数的维度都一样,比如说c=[2,32,8],也可以用c.split(1,dim=0).
c=torch.randn(2,32,8)
a,b=c.split(1,dim=0)
a.shape,b.shape
#(torch.Size([1, 32, 8]), torch.Size([1, 32, 8]))
cc=torch.rand(3,32,8)
aa,bb=cc.split([2,1],dim=0)
aa.shape,bb.shape
#(torch.Size([2, 32, 8]), torch.Size([1, 32, 8]))
但是需要注意的是这个函数只能拆分成两个,不如说c=[3,32,8],aa,bb,cc=c.split(1,dim=0)这个是不行的
chunk
这个是按照数量进行拆分,这个可以返回多个
---=.chunk(num,dim)
这个就是最终拆分成num个,比如说[6,32,8],num=2,那就是最终拆分成两个,返回[3,32,8],[3,32,8]、如果num=3,那就是最终拆分成三个,返回[2,32,8],[2,32,8],[2,32,8]
c=torch.randn(6,32,8)
aa,bb=c.chunk(2,dim=0)
aa.shape,bb.shape
# (torch.Size([3, 32, 8]), torch.Size([3, 32, 8]))
aaa,bbb,ccc=c.chunk(3,dim=0)
aaa.shape,bbb.shape,ccc.shape
# (torch.Size([2, 32, 8]), torch.Size([2, 32, 8]), torch.Size([2, 32, 8]))