PyTorch中.view()与.reshape()方法以及.resize_()方法的对比
前言
本文对PyTorch的.view()方法和.reshape()方法还有.resize_()方法进行了分析说明,关于本文出现的view和copy的语义可以看一下我之前写的文章,传送门:
-
view和copy对比:浅谈PyTorch/Numpy中view和copy/clone的区别
torch.Tensor.reshape() vs. torch.Tensor.view()
- 相同点:从功能上来看,它们的作用是相同的,都是将原张量元素(按顺序)重组为新的shape。
- 区别在于:
- .view()方法只能改变连续的(contiguous)张量,否则需要先调用.contiguous()方法,而.reshape()方法不受此限制;
- .view()方法返回的张量与原张量共享基础数据(存储器,注意不是共享内存地址,详见代码 ),而.reshape()方法返回的是原张量的copy还是view(即是否跟原张量共享存储),事先是不知道的,如果可以返回view,那么.reshape()方法返回的就是原张量的view,否则返回的就是copy。
–> 因此,为避免语义冲突:
- 如果需要原张量的拷贝(copy),就使用.clone()方法;
- 而如果需要原张量的视图(view),就使用.view()方法;
- 如果想要原张量的视图(view),但是原张量不连续(contiguous),不过原张量拥有兼容的步长(strides),此时可以考虑使用.reshape()函数。
a = torch.randint(0, 10, (3, 4)) """ Out: tensor([[3, 7, 1, 3], [6, 4, 1, 3], [8, 8, 5, 7]]) """ b = a.view(2, 6) """ Out: tensor([[3, 7, 1, 3, 6, 4], [1, 3, 8, 8, 5, 7]]) """ c = a.reshape(2, 6) """ Out: tensor([[3, 7, 1, 3, 6, 4], [1, 3, 8, 8, 5, 7]]) """ # 非严格意义上讲,id可以认为是对象的内存地址 print(id(a)==id(b), id(a)==id(c), id(b)==id(c)) """ 前提:python的变量和数据是保存在不同的内存空间中的,PyTorch中的Tensor的存储也是类似的机制,tensor相当于python变量,保存了tensor的形状(size)、步长(stride)、数据类型(type)等信息(或其引用),当然也保存了对其对应的存储器Storage的引用,存储器Storage就是对数据data的封装。 viewed对象和reshaped对象都存储在与原始对象不同的地址内存中,但是它们共享存储器Storage,也就意味着它们共享基础数据。 """ print(id(a.storage())==id(b.storage()), id(a.storage())==id(c.storage()), id(b.storage())==id(c.storage())) """ Out: False False False True True True """ a[0]=0 print(a, b, c) """ Out: tensor([[0, 0, 0, 0], [6, 4, 1, 3], [8, 8, 5, 7]]) tensor([[0, 0, 0, 0, 6, 4], [1, 3, 8, 8, 5, 7]]) tensor([[0, 0, 0, 0, 6, 4], [1, 3, 8, 8, 5, 7]]) """ c[0]=1 print(a, b, c) """ Out: tensor([[1, 1, 1, 1], [1, 1, 1, 3], [8, 8, 5, 7]]) tensor([[1, 1, 1, 1, 1, 1], [1, 3, 8, 8, 5, 7]]) tensor([[1, 1, 1, 1, 1, 1], [1, 3, 8, 8, 5, 7]]) """
torch.Tensor.resize_()
torch.Tensor.resize_() 方法的功能跟.reshape() / .view()方法的功能一样,也是将原张量元素(按顺序)重组为新的shape。
当resize前后的shape兼容时,返回原张量的视图(view);当目标大小(resize后的总元素数)大于当前大小(resize前的总元素数)时,基础存储器的大小将改变(即增大),以适应新的元素数,任何新的内存(新元素值)都是未初始化的;当目标大小(resize后的总元素数)小于当前大小(resize前的总元素数)时,基础存储器的大小保持不变,返回目标大小的元素重组后的张量,未使用的元素仍然保存在存储器中,如果再次resize回原来的大小,这些元素将会被重新使用。
(这里说的shape兼容的意思是:resize前后的shape包含的总元素数是一致的,即resize前后的shape的所有维度的乘积是相同的。如resize前,shape为(1, 2 ,3),那resize之后的张量的总元素数需要是1*2*3,故目标shape可以是(2, 3), 可以是(3, 2, 1),可以是(2, 1, 3)等尺寸。)
–> 文字说明有点干燥,看点例子感受一下:
a = torch.arange(24).view(4, 6) """ Out: tensor([[ 0, 1, 2, 3, 4, 5], [ 6, 7, 8, 9, 10, 11], [12, 13, 14, 15, 16, 17], [18, 19, 20, 21, 22, 23]]) """ a.resize_(6, 4) """ Out: tensor([[ 0, 1, 2, 3], [ 4, 5, 6, 7], [ 8, 9, 10, 11], [12, 13, 14, 15], [16, 17, 18, 19], [20, 21, 22, 23]]) """ a.resize_(3, 3) """ Out: tensor([[0, 1, 2], [3, 4, 5], [6, 7, 8]]) """ a.resize_(7, 4) """ Out: tensor([[ 0, 1, 2, 3], [ 4, 5, 6, 7], [ 8, 9, 10, 11], [ 12, 13, 14, 15], [ 16, 17, 18, 19], [ 20, 21, 22, 23], [140720147688480, 140720141167152, 1, 0]]) """
ps(官方解释,不是很能理解): 这是一个底层方法。存储被重新解释为c连续的,忽略当前的步长(除非目标大小等于当前大小,在这种情况下张量保持不变)
更多时候应该使用.view() / .reshape() / .set_()方法来替代此方法
参考文献:
What’s the difference between reshape and view in pytorch?
原文链接:https://blog.csdn.net/weixin_43002433/article/details/104299896
如果这篇文章帮助到了你,你可以请作者喝一杯咖啡