pytorch中的unsqueeze函数和squeeze函数
在pytorch中,我们经常对张量Tensor的维度进行压缩或者扩充(压缩或者扩充的维度为1)。其中经常使用的是squeeze()
函数和unsqueeze
函数;
squeeze在英文中的意思就是“挤、压”,所以故名思议,squeeze()
函数就是对张量的维度进行减少的操作,话不多说,我们直接看下例子:
import torch
#定义两个整型的张量a,b
a = torch.IntTensor([[1,2,3],[4,5,6]])
b = torch.IntTensor([[[1,2,3],[4,5,6]]])
#看一下a,b的形状
print(a.shape)
print(b.shape)
'''
===output===
torch.Size([2, 3])
torch.Size([1, 2, 3])
'''
#我们看到张量b比较膨胀,有三个维度:1*2*3,所以我们要挤压一下张量b的第0个维度(因为是1才能挤压,否则没有效果)
c = torch.squeeze(b,0) # 对应的维度为第0维
print(c.shape)
'''
===output===
torch.Size([2, 3])
'''
#那如果想想张量a膨胀一下,怎么办
c = torch.unsqueeze(a,0)
print(c.shape)
'''
===output===
torch.Size([1, 2, 3])
'''
#可以看到张量a在第0维也膨胀了, 如果你看不惯的话,再压缩一下它。
另外,squeeze()
函数和unsqueeze()
函数还有另一种写法,直接用张量类型的变量来调用这两个函数:
c = a.unsqueeze(0)
print(c.shape)
'''
===output===
torch.Size([1, 2, 3])
'''
你看出差别了么?这里直接用张量变量a
来调用了unsqueeze()
函数,当然squeeze()
也是一样的,不信你可以试试_
如果你喜欢的话...
如果读完我写的笔记有疑问或者想法,欢迎留下您的评论,我们一起交流、共同讨论、相互学习。如果这篇笔记让您有收获,愿您不吝打赏,您的鼓励是对我最大的肯定,也督促我记录更多质量更好的笔记。
![打赏码](https://www.cnblogs.com/images/cnblogs_com/datasnail/1294095/o_QQ%E6%88%AA%E5%9B%BE20180905125739.png)