transpose()和permute()
在pytorch
中转置用的函数就只有这两个:transpose()和
permute(),本文将详细地介绍这两个函数以及它们之间的区别。
transpose()
torch.transpose(input, dim0, dim1, out=None) → Tensor
函数返回输入矩阵input
的转置。交换维度dim0
和dim1
参数:
- input (Tensor) – 输入张量,必填
- dim0 (int) – 转置的第一维,默认0,可选
- dim1 (int) – 转置的第二维,默认1,可选
注意只能有两个相关的交换的位置参数。
例子:
>>> x = torch.randn(2, 3) >>> x tensor([[ 1.0028, -0.9893, 0.5809], [-0.1669, 0.7299, 0.4942]]) >>> torch.transpose(x, 0, 1) tensor([[ 1.0028, -0.1669], [-0.9893, 0.7299], [ 0.5809, 0.4942]])
permute()
参数: dims (int…*)-换位顺序,必填
例子:
>>> x = torch.randn(2, 3, 5) >>> x.size() torch.Size([2, 3, 5]) >>> x.permute(2, 0, 1).size() torch.Size([5, 2, 3])
transpose与permute的异同
- permute相当于可以同时操作于tensor的若干维度,transpose只能同时作用于tensor的两个维度;
- torch.transpose(x)合法, x.transpose()合法。torch.permute(x)不合法,x.permute()合法。
- 与contiguous、view函数之关联。contiguous:view只能作用在contiguous的variable上,如果在view之前调用了transpose、permute等,就需要调用contiguous()来返回一个contiguous copy;一种可能的解释是:有些tensor并不是占用一整块内存,而是由不同的数据块组成,而tensor的view()操作依赖于内存是整块的,这时只需要执行contiguous()这个函数,把tensor变成在内存中连续分布的形式;判断ternsor是否为contiguous,可以调用torch.Tensor.is_contiguous()函数:
import torch x = torch.ones(10, 10) x.is_contiguous() # True x.transpose(0, 1).is_contiguous() # False x.transpose(0, 1).contiguous().is_contiguous() # True
另:在pytorch的最新版本0.4版本中,增加了torch.reshape(),与 numpy.reshape() 的功能类似,大致相当于 tensor.contiguous().view(),这样就省去了对tensor做view()变换前,调用contiguous()的麻烦;