tensor的拼接与拆分

tensor的拼接与拆分

cat函数

例子:成绩单的合并

【班级1~4 学生 得分】

【班级5~9 学生 得分】

在0维进行合并,非cat的维度必须一致

a = torch.rand(4,32,8)
b = torch.rand(5,32,8)
c = torch.cat([a,b],dim=0)
c.shape()
#[9,32,8]

stack函数

会新添加一个维度,要保证两个stack的tensor的维度一摸一样,在理解方面是添加了新的概念在里面。

例子:

一班:【32个学生 每个学生8门课程】

二班:【32个学生 每个学生8门课程】

stack之后变为【两个班级 每个班级32个学生 每个学生有8门课程】

a = torch.rand(32,8)
b = torch.rand(32,8)
torch.stack([a,b],dim=0).shape
#[2 32 8]

split函数

split函数按照长度来拆分

例子1:

参数说明:【1,1】表示前面的长度为1,后面的长度也是1

a = torch.rand(2,32,8)
b,c = torch.split([1,1],dim=0)
b.shape
#[1,32,8]
c.shape()
#[1,32,8]

例子2:

参数说明:【2,1】表示前面的长度为2,后面的长度为1(不规则分割的参数含义)

a = torch.rand(3,32,8)
b,c = torch.split([2,1],dim=0)
b.shape
#[2,32,8]
c.shape()
#[1,32,8]

chunk函数

根据数量来进行分割(尽量实现整除,后面除不尽的留给最后)

例子:

a = torch.rand(6,32,8)
b,c,d= torch.chunk(a,3,dim=0)
print(b.shape)
print(c.shape)
print(d.shape)

#torch.Size([2, 32, 8])
#torch.Size([2, 32, 8])
#torch.Size([2, 32, 8])

例子2:

a = torch.rand(5,32,8)
b,c,d= torch.chunk(a,3,dim=0)
print(b.shape)
print(c.shape)
print(d.shape)

#torch.Size([2, 32, 8])
#torch.Size([2, 32, 8])
#torch.Size([1, 32, 8])

例子3:

a = torch.rand(5,32,8)
b,c= torch.chunk(a,2,dim=0)
print(b.shape)
print(c.shape)

#torch.Size([3, 32, 8])
#torch.Size([2, 32, 8])
posted @ 2020-09-01 18:34  Jason66661010  阅读(1059)  评论(0编辑  收藏  举报