PyTorch 两大转置函数 transpose() 和 permute(),
在pytorch
中转置用的函数就只有这两个
transpose()
permute()
transpose()
torch.transpose(input, dim0, dim1, out=None) → Tensor
函数返回输入矩阵input
的转置。交换维度dim0
和dim1
参数:
- input (Tensor) – 输入张量,必填
- dim0 (int) – 转置的第一维,默认0,可选
- dim1 (int) – 转置的第二维,默认1,可选
注意只能有两个相关的交换的位置参数。
permute()
参数:
dims (int…*)-换位顺序,必填
相同点
- 都是返回转置后矩阵。
- 都可以操作高纬矩阵,
permute
在高维的功能性更强。
# 创造二维数据x,dim=0时候2,dim=1时候3 x = torch.randn(2,3) 'x.shape → [2,3]' # 创造三维数据y,dim=0时候2,dim=1时候3,dim=2时候4 y = torch.randn(2,3,4) 'y.shape → [2,3,4]'
# 对于transpose x.transpose(0,1) 'shape→[3,2] ' x.transpose(1,0) 'shape→[3,2] ' y.transpose(0,1) 'shape→[3,2,4]' y.transpose(0,2,1) 'error,操作不了多维' # 对于permute() x.permute(0,1) 'shape→[2,3]' x.permute(1,0) 'shape→[3,2], 注意返回的shape不同于x.transpose(1,0) ' y.permute(0,1) "error 没有传入所有维度数" y.permute(1,0,2) 'shape→[3,2,4]'
合法性不同
torch.transpose(x)合法, x.transpose()合法。
tensor.permute(x)不合法,x.permute()合法。
参考第二点的举例
操作dim不同:
transpose()只能一次操作两个维度;permute()可以一次操作多维数据,且必须传入所有维度数,因为permute()的参数是int*。
transpose()
中的dim
没有数的大小区分;permute()
中的dim
有数的大小区分
举例,注意后面的shape
:
# 对于transpose,不区分dim大小 x1 = x.transpose(0,1) 'shape→[3,2] ' x2 = x.transpose(1,0) '也变换了,shape→[3,2] ' print(torch.equal(x1,x2)) ' True ,value和shape都一样' # 对于permute() x1 = x.permute(0,1) '不同transpose,shape→[2,3] ' x2 = x.permute(1,0) 'shape→[3,2] ' print(torch.equal(x1,x2)) 'False,和transpose不同' y1 = y.permute(0,1,2) '保持不变,shape→[2,3,4] ' y2 = y.permute(1,0,2) 'shape→[3,2,4] ' y3 = y.permute(1,2,0) 'shape→[3,4,2] '
用view()
函数改变通过转置后的数据结构,导致报错RuntimeError: invalid argument 2: view size is not compatible with input tensor's....
这是因为tensor经过转置后数据的内存地址不连续导致的,也就是tensor . is_contiguous()==False
虽然在torch里面,view函数相当于numpy的reshape,但是这时候reshape()
可以改变该tensor结构,但是view()
不可以
x = torch.rand(3,4) x = x.transpose(0,1) print(x.is_contiguous()) # 是否连续 'False' # 会发现 x.view(3,4) ''' RuntimeError: invalid argument 2: view size is not compatible with input tensor's.... 就是不连续导致的 ''' # 但是这样是可以的。 x = x.contiguous() x.view(3,4)
x = torch.rand(3,4) x = x.permute(1,0) # 等价x = x.transpose(0,1) x.reshape(3,4) '''这就不报错了 说明x.reshape(3,4) 这个操作 等于x = x.contiguous().view() 尽管如此,但是我们还是不推荐使用reshape 除非为了获取完全不同但是数据相同的克隆体 '''
调用contiguous()
时,会强制拷贝一份tensor
,让它的布局和从头创建的一毛一样。
只需要记住了,每次在使用view()
之前,该tensor
只要使用了transpose()
和permute()
这两个函数一定要contiguous()
.