PyTorch 两大转置函数 transpose() 和 permute(),

pytorch中转置用的函数就只有这两个

  1. transpose()
  2. permute()

 

transpose()
torch.transpose(input, dim0, dim1, out=None) → Tensor

函数返回输入矩阵input的转置。交换维度dim0dim1

参数:

  • input (Tensor) – 输入张量,必填
  • dim0 (int) – 转置的第一维,默认0,可选
  • dim1 (int) – 转置的第二维,默认1,可选

注意只能有两个相关的交换的位置参数。

permute()

 

参数:

dims (int…*)-换位顺序,必填

相同点

  1. 都是返回转置后矩阵。
  2. 都可以操作高纬矩阵,permute在高维的功能性更强。
复制代码
# 创造二维数据x,dim=0时候2,dim=1时候3
x = torch.randn(2,3)       'x.shape  →  [2,3]'
# 创造三维数据y,dim=0时候2,dim=1时候3,dim=2时候4
y = torch.randn(2,3,4)   'y.shape  →  [2,3,4]'
复制代码
# 对于transpose
x.transpose(0,1)     'shape→[3,2] '  
x.transpose(1,0)     'shape→[3,2] '  
y.transpose(0,1)     'shape→[3,2,4]' 
y.transpose(0,2,1)  'error,操作不了多维'

# 对于permute()
x.permute(0,1)     'shape→[2,3]'
x.permute(1,0)     'shape→[3,2], 注意返回的shape不同于x.transpose(1,0) '
y.permute(0,1)     "error 没有传入所有维度数"
y.permute(1,0,2)  'shape→[3,2,4]'
复制代码

合法性不同
torch.transpose(x)合法, x.transpose()合法。
tensor.permute(x)不合法,x.permute()合法。

参考第二点的举例

操作dim不同:
transpose()只能一次操作两个维度;permute()可以一次操作多维数据,且必须传入所有维度数,因为permute()的参数是int*。
复制代码
  1. transpose()中的dim没有数的大小区分;permute()中的dim有数的大小区分

举例,注意后面的shape

 

复制代码
# 对于transpose,不区分dim大小
x1 = x.transpose(0,1)   'shape→[3,2] '  
x2 = x.transpose(1,0)   '也变换了,shape→[3,2] '  
print(torch.equal(x1,x2))
' True ,value和shape都一样'

# 对于permute()
x1 = x.permute(0,1)     '不同transpose,shape→[2,3] '  
x2 = x.permute(1,0)     'shape→[3,2] '  
print(torch.equal(x1,x2))
'False,和transpose不同'

y1 = y.permute(0,1,2)     '保持不变,shape→[2,3,4] '  
y2 = y.permute(1,0,2)     'shape→[3,2,4] '  
y3 = y.permute(1,2,0)     'shape→[3,4,2] '  
复制代码

 

view()函数改变通过转置后的数据结构,导致报错
RuntimeError: invalid argument 2: view size is not compatible with input tensor's....

这是因为tensor经过转置后数据的内存地址不连续导致的,也就是tensor . is_contiguous()==False
虽然在torch里面,view函数相当于numpy的reshape,但是这时候reshape()可以改变该tensor结构,但是view()不可以

复制代码
x = torch.rand(3,4)
x = x.transpose(0,1)
print(x.is_contiguous()) # 是否连续
'False'
# 会发现
x.view(3,4)
'''
RuntimeError: invalid argument 2: view size is not compatible with input tensor's....
就是不连续导致的
'''
# 但是这样是可以的。
x = x.contiguous()
x.view(3,4)
复制代码

 

复制代码
x = torch.rand(3,4)
x = x.permute(1,0) # 等价x = x.transpose(0,1)
x.reshape(3,4)
'''这就不报错了
说明x.reshape(3,4) 这个操作
等于x = x.contiguous().view()
尽管如此,但是我们还是不推荐使用reshape
除非为了获取完全不同但是数据相同的克隆体
'''
复制代码

调用contiguous()时,会强制拷贝一份tensor,让它的布局和从头创建的一毛一样。

只需要记住了,每次在使用view()之前,该tensor只要使用了transpose()permute()这两个函数一定要contiguous().

 

posted on   cltt  阅读(15093)  评论(0编辑  收藏  举报

编辑推荐:
· SQL Server 2025 AI相关能力初探
· Linux系列:如何用 C#调用 C方法造成内存泄露
· AI与.NET技术实操系列(二):开始使用ML.NET
· 记一次.NET内存居高不下排查解决与启示
· 探究高空视频全景AR技术的实现原理
阅读排行:
· 阿里最新开源QwQ-32B,效果媲美deepseek-r1满血版,部署成本又又又降低了!
· 单线程的Redis速度为什么快?
· SQL Server 2025 AI相关能力初探
· AI编程工具终极对决:字节Trae VS Cursor,谁才是开发者新宠?
· 展开说说关于C#中ORM框架的用法!
历史上的今天:
2019-08-22 输出细节
2018-08-22 基本大数问题
< 2025年3月 >
23 24 25 26 27 28 1
2 3 4 5 6 7 8
9 10 11 12 13 14 15
16 17 18 19 20 21 22
23 24 25 26 27 28 29
30 31 1 2 3 4 5

导航

统计

点击右上角即可分享
微信分享提示