pytorch入门--拆分与拼接

其他相关操作:https://blog.csdn.net/qq_43923588/article/details/108007534

本篇pytorch的tensor拆分与拼接进行展示,包含:

  • cat
  • stack
  • split
  • chunk

使用方法和含义均在代码的批注中给出,因为有较多的输出,所以设置输出内容的第一个值为当前print()方法所在的行

拆分与拼接

import torch
import numpy as np
import sys
loc = sys._getframe()
_ = '\n'


'''
cat
第一个参数为一个list,其中包含所有的要用于合并的tensor
第二个参数决定在哪一个维度上进行合并,其余的维度需要保证相同
'''
a = torch.rand(4, 32, 8)
b = torch.rand(5, 32, 8)
print(loc.f_lineno, _, torch.cat([a, b], dim=0).shape)


'''
stack
必须保证合并的tensor维度完全相同
使用stack进行tensor合并,则在其要合并的维度前面添加一个维度
这个新添加的维度为要合并的tensor的数量
当其取1时,则合并后的tensor等于第一个tensor
当其取2时,则合并后的tensor等于第二个tensor
依次
相当于对原来每个tensor的选择
'''
c = torch.rand(2, 2, 16, 16)
d = torch.rand(2, 2, 16, 16)
e = torch.rand(2, 2, 16, 16)
print(loc.f_lineno, _, torch.stack([c, d, e], dim=2).shape)


'''
split
可以完成根据元素数量拆分和根据长度拆分
'''
f = torch.rand(4, 32, 5)
# 根据第一个维度进行拆分
g, gg, ggg = f.split([1, 1, 2], dim=0)
print(loc.f_lineno, _, g.shape, _, gg.shape, _, ggg.shape)
# 根据长度进行拆分,将的一个维度等分为两份
h, hh = f.split(2, dim=0)
print(loc.f_lineno, _, h.shape, _, hh.shape)


'''
chunk
使用方法类似于  h, hh = f.split(2, dim=0)
'''
i, ii = f.chunk(2, dim=0)
print(loc.f_lineno, _, i.shape, _, ii.shape)

输出结果

15 
 torch.Size([9, 32, 8])
31 
 torch.Size([2, 2, 3, 16, 16])
41 
 torch.Size([1, 32, 5]) 
 torch.Size([1, 32, 5]) 
 torch.Size([2, 32, 5])
44 
 torch.Size([2, 32, 5]) 
 torch.Size([2, 32, 5])
52 
 torch.Size([2, 32, 5]) 
 torch.Size([2, 32, 5])

Process finished with exit code 0

posted @ 2020-08-14 15:53  博0_oer~  阅读(31)  评论(0编辑  收藏  举报