pytorch permute
pytorch permute
permute(dims)
将tensor的维度换位。
参数:参数是一系列的整数,代表原来张量的维度。比如三维就有0,1,2这些dimension。
例:
import torch
import numpy as np
a=np.array([[[1,2,3],[4,5,6]]])
unpermuted=torch.tensor(a)
print(unpermuted.size()) # ——> torch.Size([1, 2, 3])
permuted=unpermuted.permute(2,0,1)
print(permuted.size()) # ——> torch.Size([3, 1, 2])
再比如图片img的size比如是(28,28,3)就可以利用img.permute(2,0,1)得到一个size为(3,28,28)的tensor。
利用这个函数permute(0,2,1)可以把Tensor([[[1,2,3],[4,5,6]]]) 转换成
tensor([[[1., 4.],
[2., 5.],
[3., 6.]]])
如果使用view,可以得到
tensor([[[1., 2.],
[3., 4.],
[5., 6.]]])
关于view的用法:参见PyTorch中view的用法
链接:https://blog.csdn.net/york1996/article/details/81876886