pytorch中的squeeze函数、unsqueeze函数
1、 squeeze()函数
- 去除size为1的维度,包括行和列
- 至于维度大于等于2时,squeeze()不起作用
- torch.unsqueeze(A,N)函数的作用减少数组A指定位置N的维度,如果指定的维度大于1,那么将操作无效,如果不指定维度N,那么将删除所有维度为1的维度
行变的例子
>>> torch.rand(4, 1, 3)
(0 ,.,.) =
0.5391 0.8523 0.9260
(1 ,.,.) =
0.2507 0.9512 0.6578
(2 ,.,.) =
0.7302 0.3531 0.9442
(3 ,.,.) =
0.2689 0.4367 0.6610
[torch.FloatTensor of size 4x1x3]
>>> torch.rand(4, 1, 3).squeeze()
0.0801 0.4600 0.1799
0.0236 0.7137 0.6128
0.0242 0.3847 0.4546
0.9004 0.5018 0.4021
[torch.FloatTensor of size 4x3]
列变的例子
>>> torch.rand(4, 3, 1)
(0 ,.,.) =
0.7013
0.9818
0.9723
(1 ,.,.) =
0.9902
0.8354
0.3864
(2 ,.,.) =
0.4620
0.0844
0.5707
(3 ,.,.) =
0.5722
0.2494
0.5815
[torch.FloatTensor of size 4x3x1]
>>> torch.rand(4, 3, 1).squeeze()
0.8784 0.6203 0.8213
0.7238 0.5447 0.8253
0.1719 0.7830 0.1046
0.0233 0.9771 0.2278
[torch.FloatTensor of size 4x3]
不变的例子
>>> torch.rand(4, 3, 2)
(0 ,.,.) =
0.6618 0.1678
0.3476 0.0329
0.1865 0.4349
(1 ,.,.) =
0.7588 0.8972
0.3339 0.8376
0.6289 0.9456
(2 ,.,.) =
0.1392 0.0320
0.0033 0.0187
0.8229 0.0005
(3 ,.,.) =
0.2327 0.6264
0.4810 0.6642
0.8625 0.6334
[torch.FloatTensor of size 4x3x2]
>>> torch.rand(4, 3, 2).squeeze()
(0 ,.,.) =
0.0593 0.8910
0.9779 0.1530
0.9210 0.2248
(1 ,.,.) =
0.7938 0.9362
0.1064 0.6630
0.9321 0.0453
(2 ,.,.) =
0.0189 0.9187
0.4458 0.9925
0.9928 0.7895
(3 ,.,.) =
0.5116 0.7253
0.0132 0.6673
0.9410 0.8159
[torch.FloatTensor of size 4x3x2]
2、unsqueeze()函数
torch.unsqueeze()函数的作用增加数组A指定位置N的维度。
例如:
两行三列的数组A维度为(2,3),那么这个数组就有三个位置可以增加维度,分别是( [位置0] 2,[位置1] 3 [位置2] )或者是 ( [位置-3] 2,[位置-2] 3 [位置-1] ),如果执行 torch.unsqueeze(A,1),数据的维度就变为了 (2,1,3)
a=torch.randn(1,3)
print(a.shape)
b=torch.unsqueeze(a,0)
print(b.shape)
c=torch.unsqueeze(a,1)
print(c.shape)
d=torch.unsqueeze(a,2)
print(d.shape)
输出
torch.Size([1, 3])
torch.Size([1, 1, 3])
torch.Size([1, 1, 3])
torch.Size([1, 3, 1])