pytorch中squeeze()和unsqueeze()函数

 

squeeze() 函数:从数组的形状中删除单维度条目,即把shape中为1的维度去掉

用法:numpy.squeeze(a,axis = None)

        a表示输入的数组;
        axis用于指定需要删除的维度,但是指定的维度必须为单维度,否则将会报错;
        axis的取值可为None 或 int 或 tuple of ints, 可选。若axis为空,则删除所有单维度的条目;
        返回值:数组;
        不会修改原数组;

    >>> a = e.reshape(1,1,10)
    >>> a
    array([[[0, 1, 2, 3, 4, 5, 6, 7, 8, 9]]])
    >>> np.squeeze(a)
    array([0, 1, 2, 3, 4, 5, 6, 7, 8, 9])
链接:https://blog.csdn.net/weixin_40730615/article/details/115488051

 

下面使用一个二维矩阵看下dim不同时呈现出的效果:

    # 创建一个3*4的全1二维tensor
    a = torch.ones(3,4)
    '''
    运行结果
    tensor([[1., 1., 1., 1.],
            [1., 1., 1., 1.],
            [1., 1., 1., 1.]])
    '''

在0维度上插入一个维度,可以看到现在a的形状变为[1, 3, 4],第0维度的大小默认是1

    a = a.unsqueeze(0)
    print(a.shape)
    '''
    运行结果
    tensor([[[1., 1., 1., 1.],
             [1., 1., 1., 1.],
             [1., 1., 1., 1.]]])
    torch.Size([1, 3, 4])
    '''

在最后一个维度上插入一个维度,形状变为[3, 4, 1]

    a = a.unsqueeze(a.dim())
    print(a.shape)
    '''
    运行结果
    tensor([[[1.],
             [1.],
             [1.],
             [1.]],
            [[1.],
             [1.],
             [1.],
             [1.]],
            [[1.],
             [1.],
             [1.],
             [1.]]])
    torch.Size([3, 4, 1])
    '''

 


————————————————

REF 转自

https://blog.csdn.net/ljwwjl/article/details/115342632

 

import torch
a = torch.ones(3,4)
################################
print(a.unsqueeze(0))  ## 1,3,4
tensor([[[1., 1., 1., 1.],
         [1., 1., 1., 1.],
         [1., 1., 1., 1.]]])
################################
print(a.unsqueeze(1))## 3,1,4
tensor([[[1., 1., 1., 1.]],

        [[1., 1., 1., 1.]],

        [[1., 1., 1., 1.]]])
################################
print(a.unsqueeze(2))# 3,4,1
tensor([[[1.],
         [1.],
         [1.],
         [1.]],

        [[1.],
         [1.],
         [1.],
         [1.]],

        [[1.],
         [1.],
         [1.],
         [1.]]])

 

posted @ 2023-10-08 08:17  emanlee  阅读(152)  评论(0编辑  收藏  举报