Pytorch-拼接与拆分
引言
拼接与拆分
- cat
- stack
- split
- chunk
cat
- numpy中使用concat,在pytorch中使用更加简写的 cat
- 完成一个拼接
- 两个向量维度相同,想要拼接的维度上的值可以不同,但是其它维度上的值必须相同。
举个例子:还是按照前面的,想将这两组班级的成绩合并起来
a[class 1-4, students, scores]
b[class 5-9, students, scores]
1 |
In[4]: a = torch.rand(4,32,8) |
理解cat:
- 行拼接:[4, 4] 与 [5, 4] 以 dim=0(行)进行拼接 —> [9, 4] 9个班的成绩合起来
- 列拼接:[4, 5] 与 [4, 3] 以 dim=1(列)进行拼接 —> [4, 8] 每个班合成8项成绩
例2:
1 |
In[7]: a1 = torch.rand(4,3,32,32) |
stack
- 创造一个新的维度(代表了新的组别)
- 要求两个tensor的size完全相同
1 |
In[19]: a1 = torch.rand(4,3,16,32) |
split
- 按长度进行拆分:单元长度/数量
- 长度相同给一个固定值
- 长度不同给一个列表
1 |
In[48]: a = torch.rand(32,8) |
chunk
- 按数量进行拆分
1 |
In[63]: s.shape |
note:对于按数量切分:chunk中的参数是要切成几份;split的常数是每份有几个。