pytorch 矩阵批量转置(torch.permute)

在某些情况下,如果你想对一个被包起来的二维数组集合里面的每一个二维数组做转置,那么就可以用torch.permute这个函数,其能够实现批量转置操作,现在让我们来看看这个函数中的维度变换的原理。

比如对于一个三维矩阵:

M = torch.tensor([
   [[2, 5],
    [3, 4]],

   [[2, 5],
   [3, 4]]])       
 

我们想对其中的每一个二维矩阵做转置操作,那么我么可以这么做:

torch.permute(M,(0,2,1))

结果如下:

tensor([[
    [2, 3],
    [5, 4]],

    [[2, 3],
    [5, 4]]])

可见,每一个二位数组都被转置了过来。

其实,permute函数中的(0,2,1)这个参数的含义就是把要操作的数组中的每个元素的坐标换成(0,2,1)的形式,比如M中的第一个二维数组中的‘3’的坐标为(0,1,0),坐标维度顺序原本为(0,1,2),那么在permute操作之后,顺序变成(0,2,1)那么这个'3'的坐标就变成了(0,1,2)。结果和上面的操作是一致的。

参考链接:

https://blog.csdn.net/qq_41740004/article/details/104712173

 

posted @ 2022-07-06 16:08  Hisi  阅读(1743)  评论(0编辑  收藏  举报