pytorch之squeeze

转自:https://blog.csdn.net/xiexu911/article/details/80820028

1.torch.squeeze() 

只会去掉维度为1的那个维度。它只会去掉维度为1的维度,像下面的没有为1的维度,就不会改变:

>>> aaa=np.ones((3,2))
>>> aaa
array([[1., 1.],
       [1., 1.],
       [1., 1.]])
>>> aaa.squeeze()
array([[1., 1.],
       [1., 1.],
       [1., 1.]])

https://docs.scipy.org/doc/numpy/reference/generated/numpy.squeeze.html

 

 2.torch.unsqueeze()

>>> a=torch.tensor([1,2])
>>> a.size()
torch.Size([2])

>>> b=a.unsqueeze(1)
>>> b
tensor([[1],
        [2]])
>>> b.size()
torch.Size([2, 1])

>>> c=a.unsqueeze(0)
>>> c
tensor([[1, 2]])
>>> c.size()
torch.Size([1, 2])

就是在第i个维度上多加一个维度,对于b来说,是第二个维度,对于c来说,是第一个维度。

原来是[2]长的list那种形式,在unsqueeze(0)的时候就变成了[1,2],在unsqueeze(1)的时候就变成了[2,1]

 

posted @ 2021-06-22 15:18  lypbendlf  阅读(111)  评论(0编辑  收藏  举报