pytorch入门--拆分与拼接
其他相关操作:https://blog.csdn.net/qq_43923588/article/details/108007534
本篇pytorch的tensor拆分与拼接进行展示,包含:
- cat
- stack
- split
- chunk
使用方法和含义均在代码的批注中给出,因为有较多的输出,所以设置输出内容的第一个值为当前print()方法所在的行
拆分与拼接
import torch
import numpy as np
import sys
loc = sys._getframe()
_ = '\n'
'''
cat
第一个参数为一个list,其中包含所有的要用于合并的tensor
第二个参数决定在哪一个维度上进行合并,其余的维度需要保证相同
'''
a = torch.rand(4, 32, 8)
b = torch.rand(5, 32, 8)
print(loc.f_lineno, _, torch.cat([a, b], dim=0).shape)
'''
stack
必须保证合并的tensor维度完全相同
使用stack进行tensor合并,则在其要合并的维度前面添加一个维度
这个新添加的维度为要合并的tensor的数量
当其取1时,则合并后的tensor等于第一个tensor
当其取2时,则合并后的tensor等于第二个tensor
依次
相当于对原来每个tensor的选择
'''
c = torch.rand(2, 2, 16, 16)
d = torch.rand(2, 2, 16, 16)
e = torch.rand(2, 2, 16, 16)
print(loc.f_lineno, _, torch.stack([c, d, e], dim=2).shape)
'''
split
可以完成根据元素数量拆分和根据长度拆分
'''
f = torch.rand(4, 32, 5)
# 根据第一个维度进行拆分
g, gg, ggg = f.split([1, 1, 2], dim=0)
print(loc.f_lineno, _, g.shape, _, gg.shape, _, ggg.shape)
# 根据长度进行拆分,将的一个维度等分为两份
h, hh = f.split(2, dim=0)
print(loc.f_lineno, _, h.shape, _, hh.shape)
'''
chunk
使用方法类似于 h, hh = f.split(2, dim=0)
'''
i, ii = f.chunk(2, dim=0)
print(loc.f_lineno, _, i.shape, _, ii.shape)
输出结果
15
torch.Size([9, 32, 8])
31
torch.Size([2, 2, 3, 16, 16])
41
torch.Size([1, 32, 5])
torch.Size([1, 32, 5])
torch.Size([2, 32, 5])
44
torch.Size([2, 32, 5])
torch.Size([2, 32, 5])
52
torch.Size([2, 32, 5])
torch.Size([2, 32, 5])
Process finished with exit code 0