PyTorch Basics -- Dimension Transformations

Other related operations: https://blog.csdn.net/qq_43923588/article/details/108007534

This post demonstrates PyTorch's dimension transformation operations, covering:

  • view/reshape
  • squeeze/unsqueeze
  • expand/repeat
  • transpose/t/permute
  • broadcast

Usage and meaning are explained in the code comments. Because there is a lot of output, the first value printed by each print() call is the line number of that call in the source file.

Dimension Transformations

import torch
import numpy as np
import sys
loc = sys._getframe()
_ = '\n'


'''
view/reshape: change the shape of a tensor
view and reshape are used almost identically; reshape was added for consistency
with numpy and also works on non-contiguous tensors
'''
a = torch.rand(4, 1, 28, 28)
print(loc.f_lineno, _, a.shape)
# Use view to merge the last three dimensions of a; the new shape must keep the total number of elements unchanged (the merged size is the product of the original sizes)
print(loc.f_lineno, _, a.view(4, 1*28*28), _, a.view(4, 1*28*28).shape)

# The concrete interpretation is decided by how the result is used, e.g. stacking the images along a single leading dimension as below
print(loc.f_lineno, _, a.view(4*1, 28, 28), _, a.view(4*1, 28, 28).shape)

# Caveat: after flattening, restoring the tensor to a different shape (here [1, 4, 28, 28] instead of [4, 1, 28, 28]) no longer carries the original meaning
b = a.view(4, 1*28*28)
print(loc.f_lineno, _, b.view(1, 4, 28, 28), _, b.view(1, 4, 28, 28).shape)
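
# Supplementary sketch (not part of the original output): the practical difference
# between view and reshape. view requires contiguous memory, so it fails on a
# transposed (non-contiguous) tensor, while reshape copies the data when necessary.
v = torch.rand(3, 4).t()            # transposing makes the tensor non-contiguous
try:
    v.view(12)                      # raises RuntimeError on a non-contiguous tensor
except RuntimeError:
    pass
assert v.reshape(12).shape == torch.Size([12])             # reshape copies if needed
assert v.contiguous().view(12).shape == torch.Size([12])   # or make it contiguous first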


'''squeeze/unsqueeze: remove and add dimensions'''
b = torch.rand(4, 1, 28, 28)
# Use unsqueeze() to add a dimension: insert a new dimension at position 0 of b
print(loc.f_lineno, _, b.unsqueeze(0).shape)
# Insert a new dimension at the last position of b
print(loc.f_lineno, _, b.unsqueeze(-1).shape)
print(loc.f_lineno, _, b.unsqueeze(4).shape)
# Negative indices insert from the back; for a tensor with x dimensions, the valid argument range of unsqueeze() is [-x-1, x+1), right boundary excluded
print(loc.f_lineno, _, b.unsqueeze(-4).shape)
print(loc.f_lineno, _, b.unsqueeze(-5).shape)
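
# Supplementary sketch (not part of the original output): an index outside
# [-x-1, x] does not insert anything, it raises an IndexError.
try:
    b.unsqueeze(5)                  # b has 4 dimensions, so 5 is out of range
except IndexError:
    pass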

# Example: unsqueeze bb until its layout lines up with that of bbb
bb = torch.rand(32)
bbb = torch.rand(4, 32, 14, 14)
bb = bb.unsqueeze(1).unsqueeze(2).unsqueeze(0)
print(loc.f_lineno, _, bb.shape, _, bbb.shape)
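
# Supplementary sketch (not part of the original output): the point of reshaping bb
# is that it can now act as a per-channel bias for bbb, either through an explicit
# expand or implicitly through broadcasting.
assert bb.expand(4, 32, 14, 14).shape == bbb.shape
assert (bbb + bb).shape == bbb.shape     # broadcasting handles the expansion automatically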

'''Use squeeze to remove dimensions; a dim argument may be given, and if it is omitted every dimension of size 1 is removed'''
print(loc.f_lineno, _, b.shape)
print(loc.f_lineno, _, b.squeeze().shape)
c = torch.rand(1, 1, 20, 20)
print(loc.f_lineno, _, c.shape, _, c.squeeze(-4).shape)
# Squeezing a dimension whose size is not 1 leaves the tensor unchanged
print(loc.f_lineno, _, c.shape, _, c.squeeze(2).shape)
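
# Supplementary sketch (not part of the original output): squeeze(dim) only removes
# the given dimension when its size is 1, while squeeze() with no argument removes
# every size-1 dimension at once.
assert c.squeeze(0).shape == torch.Size([1, 20, 20])    # only the first size-1 dim is removed
assert c.squeeze().shape == torch.Size([20, 20])        # all size-1 dims are removed
assert c.unsqueeze(0).squeeze(0).shape == c.shape       # unsqueeze and squeeze undo each other here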


'''
expand/repeat: expand tensor dimensions
expand does not copy any data; it only creates a broadcasted view
repeat copies the existing data into the newly expanded dimensions
'''
# This makes elementwise operations between tensors of different shapes convenient
d = torch.rand(1, 4, 1, 1)
e = torch.rand(2, 4, 2, 2)
print(loc.f_lineno, _, d, _, e)
# To operate on d and e elementwise, expand d to the same shape as e
# expand requires the same number of dimensions, and every dimension being enlarged must originally have size 1
print(loc.f_lineno, _, d.expand(2, 4, 2, 2), _, d.shape)
# Passing -1 for a dimension keeps its original size unchanged
print(loc.f_lineno, _, d.expand(-1, 4, 2, -1), _, d.shape)

# repeat actively copies the data in memory; one repeat count must be given per dimension
f = torch.rand(1, 4, 2, 2)
# The arguments of repeat are the number of times each dimension is tiled; each size is multiplied by the corresponding count
print(loc.f_lineno, _, f.repeat(4, 4, 4, 4).shape)
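
# Supplementary sketch (not part of the original output): expand returns a view that
# shares the underlying storage (no data is copied), while repeat allocates and fills
# new memory.
assert d.expand(2, 4, 2, 2).data_ptr() == d.data_ptr()    # same storage, only strides change
assert f.repeat(2, 1, 1, 1).data_ptr() != f.data_ptr()    # a fresh copy of the data
assert f.repeat(2, 1, 1, 1).shape == torch.Size([2, 4, 2, 2])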


'''transpose/t/permute: matrix transpose and dimension reordering'''
g = torch.randn(3, 4)
# Use t() to transpose a 2-D matrix
print(loc.f_lineno, _, g, _, g.t())
# Use transpose() to swap two specified dimensions
h = torch.randn(1, 2, 4, 4)
print(loc.f_lineno, _, h.shape, _, h.transpose(0, 3).shape)
# Use permute to reorder all dimensions; a new position must be given for every dimension
i = torch.randn(1, 2, 3, 4)
print(loc.f_lineno, _, i.shape, _, i.permute(3, 2, 1, 0).shape)
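
# Supplementary sketch (not part of the original output): for 2-D tensors t() is just
# transpose(0, 1), and transpose/permute return non-contiguous views, which is why they
# are often followed by .contiguous().
assert torch.equal(g.t(), g.transpose(0, 1))
assert not h.transpose(1, 3).is_contiguous()
assert h.permute(0, 2, 3, 1).contiguous().is_contiguous()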


'''
Broadcast: automatic dimension expansion
When adding tensors, the values along each dimension must be aligned;
broadcasting expands a tensor to the largest shape among the tensors involved in the operation
'''
j = torch.randn(2, 4)
k = torch.randn(1)
print(loc.f_lineno, _, j, _, k)
# Broadcast k to match the shape of j
print(loc.f_lineno, _, torch.broadcast_tensors(j, k))
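
# Supplementary sketch (not part of the original output): broadcasting follows the same
# rules as an explicit expand, right-aligning the shapes and enlarging size-1 dimensions,
# so j + k gives the same result as expanding k by hand.
assert (j + k).shape == torch.Size([2, 4])
assert torch.equal(j + k, j + k.expand(2, 4))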

Output

13 
 torch.Size([4, 1, 28, 28])
15 
 tensor([[0.1362, 0.0844, 0.9731,  ..., 0.8592, 0.0862, 0.5167],
        [0.4659, 0.3503, 0.6879,  ..., 0.5549, 0.0063, 0.6218],
        [0.6778, 0.3196, 0.4582,  ..., 0.2253, 0.0280, 0.6639],
        [0.7047, 0.3789, 0.0595,  ..., 0.4078, 0.8520, 0.7480]]) 
 torch.Size([4, 784])
18 
 tensor([[[0.1362, 0.0844, 0.9731,  ..., 0.8755, 0.0334, 0.4802],
         [0.1612, 0.9726, 0.8609,  ..., 0.4897, 0.9330, 0.1100],
         [0.7621, 0.3461, 0.8166,  ..., 0.2286, 0.9856, 0.8432],
         ...,
         [0.4147, 0.7541, 0.4794,  ..., 0.6909, 0.8697, 0.9461],
         [0.9182, 0.1205, 0.9826,  ..., 0.9463, 0.4773, 0.8543],
         [0.3497, 0.5007, 0.5041,  ..., 0.8592, 0.0862, 0.5167]],

        [[0.4659, 0.3503, 0.6879,  ..., 0.7880, 0.4007, 0.8897],
         [0.0991, 0.3669, 0.8622,  ..., 0.2973, 0.4850, 0.9193],
         [0.0536, 0.3126, 0.0984,  ..., 0.7140, 0.0592, 0.7686],
         ...,
         [0.9857, 0.7800, 0.6523,  ..., 0.2840, 0.1137, 0.1010],
         [0.8461, 0.7320, 0.0499,  ..., 0.9291, 0.8517, 0.6766],
         [0.4217, 0.2075, 0.4141,  ..., 0.5549, 0.0063, 0.6218]],

        [[0.6778, 0.3196, 0.4582,  ..., 0.5867, 0.5059, 0.9667],
         [0.8183, 0.5885, 0.6947,  ..., 0.6772, 0.3649, 0.2795],
         [0.8529, 0.0812, 0.8268,  ..., 0.7841, 0.2325, 0.1840],
         ...,
         [0.0846, 0.4687, 0.3338,  ..., 0.9408, 0.0029, 0.9178],
         [0.1213, 0.2351, 0.6759,  ..., 0.3001, 0.5185, 0.9210],
         [0.7112, 0.8535, 0.1676,  ..., 0.2253, 0.0280, 0.6639]],

        [[0.7047, 0.3789, 0.0595,  ..., 0.2555, 0.0403, 0.8277],
         [0.4509, 0.0456, 0.6582,  ..., 0.1604, 0.4694, 0.0943],
         [0.7096, 0.7146, 0.0102,  ..., 0.0807, 0.7584, 0.7038],
         ...,
         [0.5749, 0.1963, 0.6901,  ..., 0.9132, 0.3689, 0.2546],
         [0.1389, 0.4381, 0.5972,  ..., 0.1258, 0.7157, 0.4518],
         [0.9226, 0.3656, 0.0768,  ..., 0.4078, 0.8520, 0.7480]]]) 
 torch.Size([4, 28, 28])
22 
 tensor([[[[0.1362, 0.0844, 0.9731,  ..., 0.8755, 0.0334, 0.4802],
          [0.1612, 0.9726, 0.8609,  ..., 0.4897, 0.9330, 0.1100],
          [0.7621, 0.3461, 0.8166,  ..., 0.2286, 0.9856, 0.8432],
          ...,
          [0.4147, 0.7541, 0.4794,  ..., 0.6909, 0.8697, 0.9461],
          [0.9182, 0.1205, 0.9826,  ..., 0.9463, 0.4773, 0.8543],
          [0.3497, 0.5007, 0.5041,  ..., 0.8592, 0.0862, 0.5167]],

         [[0.4659, 0.3503, 0.6879,  ..., 0.7880, 0.4007, 0.8897],
          [0.0991, 0.3669, 0.8622,  ..., 0.2973, 0.4850, 0.9193],
          [0.0536, 0.3126, 0.0984,  ..., 0.7140, 0.0592, 0.7686],
          ...,
          [0.9857, 0.7800, 0.6523,  ..., 0.2840, 0.1137, 0.1010],
          [0.8461, 0.7320, 0.0499,  ..., 0.9291, 0.8517, 0.6766],
          [0.4217, 0.2075, 0.4141,  ..., 0.5549, 0.0063, 0.6218]],

         [[0.6778, 0.3196, 0.4582,  ..., 0.5867, 0.5059, 0.9667],
          [0.8183, 0.5885, 0.6947,  ..., 0.6772, 0.3649, 0.2795],
          [0.8529, 0.0812, 0.8268,  ..., 0.7841, 0.2325, 0.1840],
          ...,
          [0.0846, 0.4687, 0.3338,  ..., 0.9408, 0.0029, 0.9178],
          [0.1213, 0.2351, 0.6759,  ..., 0.3001, 0.5185, 0.9210],
          [0.7112, 0.8535, 0.1676,  ..., 0.2253, 0.0280, 0.6639]],

         [[0.7047, 0.3789, 0.0595,  ..., 0.2555, 0.0403, 0.8277],
          [0.4509, 0.0456, 0.6582,  ..., 0.1604, 0.4694, 0.0943],
          [0.7096, 0.7146, 0.0102,  ..., 0.0807, 0.7584, 0.7038],
          ...,
          [0.5749, 0.1963, 0.6901,  ..., 0.9132, 0.3689, 0.2546],
          [0.1389, 0.4381, 0.5972,  ..., 0.1258, 0.7157, 0.4518],
          [0.9226, 0.3656, 0.0768,  ..., 0.4078, 0.8520, 0.7480]]]]) 
 torch.Size([1, 4, 28, 28])
28 
 torch.Size([1, 4, 1, 28, 28])
30 
 torch.Size([4, 1, 28, 28, 1])
31 
 torch.Size([4, 1, 28, 28, 1])
33 
 torch.Size([4, 1, 1, 28, 28])
34 
 torch.Size([1, 4, 1, 28, 28])
40 
 torch.Size([1, 32, 1, 1]) 
 torch.Size([4, 32, 14, 14])
43 
 torch.Size([4, 1, 28, 28])
44 
 torch.Size([4, 28, 28])
46 
 torch.Size([1, 1, 20, 20]) 
 torch.Size([1, 20, 20])
48 
 torch.Size([1, 1, 20, 20]) 
 torch.Size([1, 1, 20, 20])
59 
 tensor([[[[0.0641]],

         [[0.6581]],

         [[0.9605]],

         [[0.0330]]]]) 
 tensor([[[[0.9558, 0.8657],
          [0.7481, 0.0709]],

         [[0.5156, 0.5102],
          [0.5354, 0.6373]],

         [[0.7915, 0.0628],
          [0.7903, 0.8029]],

         [[0.0228, 0.1665],
          [0.5545, 0.2742]]],


        [[[0.4379, 0.0240],
          [0.7396, 0.7296]],

         [[0.8305, 0.6955],
          [0.5117, 0.8612]],

         [[0.2120, 0.6087],
          [0.5026, 0.8842]],

         [[0.2343, 0.9974],
          [0.9827, 0.8795]]]])
62 
 tensor([[[[0.0641, 0.0641],
          [0.0641, 0.0641]],

         [[0.6581, 0.6581],
          [0.6581, 0.6581]],

         [[0.9605, 0.9605],
          [0.9605, 0.9605]],

         [[0.0330, 0.0330],
          [0.0330, 0.0330]]],


        [[[0.0641, 0.0641],
          [0.0641, 0.0641]],

         [[0.6581, 0.6581],
          [0.6581, 0.6581]],

         [[0.9605, 0.9605],
          [0.9605, 0.9605]],

         [[0.0330, 0.0330],
          [0.0330, 0.0330]]]]) 
 torch.Size([1, 4, 1, 1])
64 
 tensor([[[[0.0641],
          [0.0641]],

         [[0.6581],
          [0.6581]],

         [[0.9605],
          [0.9605]],

         [[0.0330],
          [0.0330]]]]) 
 torch.Size([1, 4, 1, 1])
69 
 torch.Size([4, 16, 8, 8])
75 
 tensor([[ 0.5465,  1.4349, -2.3023, -1.0372],
        [-0.4149,  2.9526, -1.0680, -1.2509],
        [ 1.2386, -0.6260,  0.8532,  1.8979]]) 
 tensor([[ 0.5465, -0.4149,  1.2386],
        [ 1.4349,  2.9526, -0.6260],
        [-2.3023, -1.0680,  0.8532],
        [-1.0372, -1.2509,  1.8979]])
78 
 torch.Size([1, 2, 4, 4]) 
 torch.Size([4, 2, 4, 1])
81 
 torch.Size([1, 2, 3, 4]) 
 torch.Size([4, 3, 2, 1])
91 
 tensor([[-1.1239e+00,  9.9837e-01, -1.2058e+00,  1.6464e+00],
        [-6.5159e-01, -4.2966e-01, -8.9623e-05, -5.9659e-01]]) 
 tensor([-0.2841])
93 
 (tensor([[-1.1239e+00,  9.9837e-01, -1.2058e+00,  1.6464e+00],
        [-6.5159e-01, -4.2966e-01, -8.9623e-05, -5.9659e-01]]), tensor([[-0.2841, -0.2841, -0.2841, -0.2841],
        [-0.2841, -0.2841, -0.2841, -0.2841]]))

Process finished with exit code 0
