Fork me on GitHub

PyTorch中.view()与.reshape()方法以及.resize_()方法的对比

前言

本文对PyTorch的.view()方法和.reshape()方法还有.resize_()方法进行了分析说明,关于本文出现的view和copy的语义可以看一下我之前写的文章,传送门:



torch.Tensor.reshape() vs. torch.Tensor.view()

  • 相同点:从功能上来看,它们的作用是相同的,都是将原张量元素(按顺序)重组为新的shape
  • 区别在于:
    • .view()方法只能改变连续的(contiguous)张量,否则需要先调用.contiguous()方法,而.reshape()方法不受此限制
    • .view()方法返回的张量与原张量共享基础数据(存储器,注意不是共享内存地址,详见代码 ),而.reshape()方法返回的是原张量的copy还是view(即是否跟原张量共享存储),事先是不知道的,如果可以返回view,那么.reshape()方法返回的就是原张量的view,否则返回的就是copy

–> 因此,为避免语义冲突:

  1. 如果需要原张量的拷贝(copy),就使用.clone()方法
  2. 而如果需要原张量的视图(view),就使用.view()方法
  3. 如果想要原张量的视图(view),但是原张量不连续(contiguous),不过原张量拥有兼容的步长(strides),此时可以考虑使用.reshape()函数
a = torch.randint(0, 10, (3, 4))
"""
Out:
tensor([[3, 7, 1, 3],
        [6, 4, 1, 3],
        [8, 8, 5, 7]])
"""

b = a.view(2, 6)
"""
Out:
tensor([[3, 7, 1, 3, 6, 4],
        [1, 3, 8, 8, 5, 7]])
"""

c = a.reshape(2, 6)
"""
Out:
tensor([[3, 7, 1, 3, 6, 4],
        [1, 3, 8, 8, 5, 7]])
"""

# 非严格意义上讲,id可以认为是对象的内存地址
print(id(a)==id(b), id(a)==id(c), id(b)==id(c))
"""
前提:python的变量和数据是保存在不同的内存空间中的,PyTorch中的Tensor的存储也是类似的机制,tensor相当于python变量,保存了tensor的形状(size)、步长(stride)、数据类型(type)等信息(或其引用),当然也保存了对其对应的存储器Storage的引用,存储器Storage就是对数据data的封装。
viewed对象和reshaped对象都存储在与原始对象不同的地址内存中,但是它们共享存储器Storage,也就意味着它们共享基础数据。
"""
print(id(a.storage())==id(b.storage()), 
	  id(a.storage())==id(c.storage()),
	  id(b.storage())==id(c.storage()))
"""
Out:
False False False
True True True
"""

a[0]=0
print(a, b, c)
"""
Out:
tensor([[0, 0, 0, 0],
        [6, 4, 1, 3],
        [8, 8, 5, 7]])
tensor([[0, 0, 0, 0, 6, 4],
        [1, 3, 8, 8, 5, 7]])
tensor([[0, 0, 0, 0, 6, 4],
        [1, 3, 8, 8, 5, 7]])
"""

c[0]=1
print(a, b, c)
"""
Out:
tensor([[1, 1, 1, 1],
        [1, 1, 1, 3],
        [8, 8, 5, 7]])
tensor([[1, 1, 1, 1, 1, 1],
        [1, 3, 8, 8, 5, 7]])
tensor([[1, 1, 1, 1, 1, 1],
        [1, 3, 8, 8, 5, 7]])
"""

  

 


torch.Tensor.resize_()

torch.Tensor.resize_() 方法的功能跟.reshape() / .view()方法的功能一样,也是将原张量元素(按顺序)重组为新的shape。

当resize前后的shape兼容时,返回原张量的视图(view);当目标大小(resize后的总元素数)大于当前大小(resize前的总元素数)时,基础存储器的大小将改变(即增大),以适应新的元素数,任何新的内存(新元素值)都是未初始化的;当目标大小(resize后的总元素数)小于当前大小(resize前的总元素数)时,基础存储器的大小保持不变,返回目标大小的元素重组后的张量,未使用的元素仍然保存在存储器中,如果再次resize回原来的大小,这些元素将会被重新使用。

(这里说的shape兼容的意思是:resize前后的shape包含的总元素数是一致的,即resize前后的shape的所有维度的乘积是相同的。如resize前,shape为(1, 2 ,3),那resize之后的张量的总元素数需要是1*2*3,故目标shape可以是(2, 3), 可以是(3, 2, 1),可以是(2, 1, 3)等尺寸。)

–> 文字说明有点干燥,看点例子感受一下:

a = torch.arange(24).view(4, 6)
"""
Out:
tensor([[ 0,  1,  2,  3,  4,  5],
        [ 6,  7,  8,  9, 10, 11],
        [12, 13, 14, 15, 16, 17],
        [18, 19, 20, 21, 22, 23]])
"""

a.resize_(6, 4)
"""
Out:
tensor([[ 0,  1,  2,  3],
        [ 4,  5,  6,  7],
        [ 8,  9, 10, 11],
        [12, 13, 14, 15],
        [16, 17, 18, 19],
        [20, 21, 22, 23]])
"""

a.resize_(3, 3)
"""
Out:
tensor([[0, 1, 2],
        [3, 4, 5],
        [6, 7, 8]])
"""

a.resize_(7, 4)
"""
Out:
tensor([[              0,               1,               2,               3],
        [              4,               5,               6,               7],
        [              8,               9,              10,              11],
        [             12,              13,              14,              15],
        [             16,              17,              18,              19],
        [             20,              21,              22,              23],
        [140720147688480, 140720141167152,               1,               0]])
"""

  

 

ps(官方解释,不是很能理解): 这是一个底层方法。存储被重新解释为c连续的,忽略当前的步长(除非目标大小等于当前大小,在这种情况下张量保持不变)

更多时候应该使用.view() / .reshape() / .set_()方法来替代此方法



参考文献:

What’s the difference between reshape and view in pytorch?

 

原文链接:https://blog.csdn.net/weixin_43002433/article/details/104299896

posted @ 2021-02-15 10:33  stardsd  阅读(10198)  评论(0编辑  收藏  举报