Pytorch——Tensor的储存机制以及view()、reshape()、reszie_()三者的关系和区别
本文主要介绍Pytorch中Tensor的储存机制,在搞懂了Tensor在计算机中是如何存储之后我们会进一步来探究tensor.view()、tensor.reshape()、tensor.reszie_(),她们都是改变了一个tensor的“形状”,但是他们之间又有着些许的不同,这些不同常常会导致我们程序之中出现很多的BUG。
一、Tensor的储存机制
tensor在电脑的储存,分为两个部分(也就是说一个tensor占用了两个内存位置),一个内存储存了这个tensor的形状size、步长stride、数据的索引等信息,我们把这一部分称之为头信息区(Tensor);另一个内存储的就是真正的数据,我们称为存储区 (Storage)。换句话说,一旦定义了一个tensor,那这个tensor将会占据两个内存位置,用于存储。
要注意,如果我们把一个tensorA进行切片,截取,修改之后通过"="赋值给B,那么这个时候tensorB其实是和tensorA是共享存储区 (Storage),唯一不同的是头信息区(Tensor)不同。下面我们直接看代码来理解。其中tensor.storage().data_ptr()是用于获取tensor储存区的首元素内存地址的。
1 2 3 4 5 6 7 | A = torch.arange( 5 ) # tensor([0, 1, 2, 3, 4]) B = A[ 2 :] # 对A进行截取获得:tensor([2, 3, 4]) print (A) print (B)tensor([ 0 , 1 , 2 , 3 , 4 ]) tensor([ 2 , 3 , 4 ]) print (A.storage().data_ptr()) print (B.storage().data_ptr()) 2076006947200 2076006947200 |
我们可以很直观的看到,A和B的储存区的内存地址是一样的,因此她们是共享数据的,下面这个例子更加直观。
1 2 3 4 5 6 7 8 | import torch A = torch.arange( 5 ) # tensor([0, 1, 2, 3, 4]) B = A[ 2 :] # 对A进行截取获得:tensor([2, 3, 4]) B[ 1 ] = 100 # 修改B的第2位置元素为100 print (A) print (B)tensor([ 0 , 1 , 2 , 100 , 4 ]) tensor([ 2 , 100 , 4 ]) |
因此我们可以得出结论,通过=直接赋值的操作其实就是“浅拷贝”(这里注意和list的切片区分,list使用A[2:],是可以得到新的一个list的)
二、tensor的stride()属性、storage_offset()属性
为了更好的解释tensor的reshape(),以及view()的操作,我们还需要了解下tensor的stride属性。刚才上面我们提到了,tensor为了节约内存,很多操作其实都是在更改tensor的头信息区(Tensor),因为头信息区里面包含了如何组织数据,以及从哪里开始组织。其中stride()和storage_offset()属性分别代表的就是步长以及初始偏移量。
storage_offset()属性
表示tensor的第一个元素与真实存储区(storage)的第一个元素的偏移量。例如下面的例子:
1 2 3 4 5 6 7 8 9 10 11 | import torch A = torch.arange( 5 ) B = A[ 2 :] C = A[ 1 :] print (A) print (B) print (C)tensor([ 0 , 1 , 2 , 3 , 4 ]) tensor([ 2 , 3 , 4 ]) tensor([ 1 , 2 , 3 , 4 ]) print (B.storage_offset()) print (C.storage_offset()) 2 1 |
我们可以看到tensorB和tensorC都是从A切片而来的,她们俩的存储区 (Storage)是和A共享的,只不过B的第一个元素,与存储区 (Storage)的首元素相差了2个位置(也就是储存区的index=2开始),C的第一个元素与存储区 (Storage)的首元素相差了1个位置。
stride()属性
这个属性比较难理解,直接翻译官方文档就是:stride是在指定维度dim中从一个元素跳到下一个元素所必需的步长。直接上例子:
1 2 3 4 5 6 7 8 9 10 11 12 | import torch A = torch.rand( 2 , 3 ) # 生成2*3的随机数 print (A) print (A.storage()) # 打印A的储存区真实的数据打印A: tensor([[0.8438, 0.2782, 0.9584], [ 0.2089 , 0.0259 , 0.3666 ]]) 0.8437800407409668 0.2781521677970886 0.9583932757377625 0.2088671326637268 0.025857746601104736 0.366576611995697 [torch.FloatStorage of size 6 ] print (A.stride())( 3 , 1 ) |
主要是理解这个(3,1)指的是什么意思。这里的3指的是A[i][j]到A[i+1][j]这两个数字在存储区真实数据排列中是相差3的(例如A[0][0]=0.8438与A[1][0]=0.2089这两个数字在储存区中位次相差了3);这里的1是指A[i][j]与A[i][j+1]这两个数字在储存区的真实数据排列中相差1(例如A[0][0]=0.8438与A[0][1]=0.2781这两个数字在储存区中位次相差1)。如果还没有理解,加下来我们试一下对于3维数据看看他们的stride()属性。
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 | import torch A = torch.rand( 2 , 3 , 4 ) # 生成2*3*4的随机数 print ( "打印A:" ,A) print (A.storage()) # 打印A的储存区真实的数据打印A: tensor([[[0.4303, 0.7474, 0.8649, 0.5006], [ 0.2716 , 0.9966 , 0.7765 , 0.6737 ], [ 0.5515 , 0.2274 , 0.9791 , 0.1940 ]], [[ 0.6401 , 0.7746 , 0.5124 , 0.0258 ], [ 0.8576 , 0.9118 , 0.9504 , 0.4675 ], [ 0.9359 , 0.0687 , 0.2457 , 0.3604 ]]]) 0.4302864074707031 0.747403085231781 0.8648527264595032 0.500649631023407 0.2716004252433777 0.9965775609016418 0.7765441536903381 0.6737198233604431 0.5515168905258179 0.2273930311203003 0.9791405200958252 0.19399094581604004 0.6401097774505615 0.7746065855026245 0.512383759021759 0.02578103542327881 0.8575518727302551 0.911821186542511 0.9503545165061951 0.4674733877182007 0.9358749389648438 0.06866037845611572 0.24573636054992676 0.3603515625print (A.stride())( 12 , 4 , 1 ) |
输出有点长,大家对照着看,由于我们A的size是3维度的,因此我们A.stride()也是个三元组,那如果A是4维呢?(A.stride()一定就是4元组了)。这里的12表示就是A[i][j][k]与A[i+1][j][k]这两个数字在真实储存区的数据排布中相差12,大家可以对照的找几个数字试试。同样的道理这里的4表示A[i][j][k]与A[i][j+1][k]这两个数字在真实储存区的数据排布中相差4。最后1表示什么我就不说啦。
好了终于说完这个很难的知识点了,接下来就进入正题,view()、reshape()、reszie_()三者的关系和区别。
三、view()、reshape()、reszie_()三者的关系和区别
其中view()和reshape()是官方比较推荐使用的方式,而resize_()官方在文档中说到不太推荐使用,具体原因一会说到。这三个方法都是可以完成对以一个tensor重新排列,没错是重新排列,其实她们本质上都没有改变tensor的存储区 (Storage)的真实数据的排列(除了一些特殊情况下会使得存储区发生改变,这就是她们间的区别)。
view()
从字面上来说就是"视图"的意思,就是把存储区 (Storage)的真实数据,根据某种排列方式”展示“给你看罢了,也就是仅仅改变了头信息区(Tensor),真实数据的储存地址是没有改变的。直接上例子。
1 2 3 4 5 6 7 8 9 | import torch A = torch.arange( 6 ) B = A.view( 2 , 3 ) print (A) print (B)tensor([ 0 , 1 , 2 , 3 , 4 , 5 ]) tensor([[ 0 , 1 , 2 ], [ 3 , 4 , 5 ]]) print (A.storage().data_ptr()) print (B.storage().data_ptr()) 1881582170752 1881582170752 |
可以看到,A和B的真实数据的内存地址都是一样的,下面我们进一步打印一下A,B两个tensor真实数据的排列。
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 | print (A.storage()) print (B.storage()) 0 1 2 3 4 5 [torch.LongStorage of size 6 ] 0 1 2 3 4 5 [torch.LongStorage of size 6 ] |
可以看到,是完全一样的。更进一步打印一下A,B的stride()属性
1 2 3 | print (A.stride()) print (B.stride())( 1 ,) ( 3 , 1 ) |
没问题和前面说的是一样的。
总结一下,view()函数主要就是更改了tensor中的stride()属性,这样从而影响了tensor的显示,但是从本质上来说A,B还是共用真实数据的存储区 (Storag)的。
reshape()
为了解释view()和reshape()的区别,我们还需要知道一个知识:tensor的连续性。tensor又不是函数哪里来什么连续性?其实tensor的连续性说的就是stride()属性和size()属性(tensor维度)之间的关系。
前一小结已经说了对于一个高维的tensor,stride()指的是:指定维度dim中从一个元素跳到下一个元素所必需的步长。一般来说我们最后一个维度步长应该是1(其实我们前面的例子我们应该也能发现,例子中所有tensor.stride()返回的元组最后一个元素都是1),对吧,因为是按顺序排列的嘛。但是当一个tensor涉及到转置(tensor.t(),tensor.transpose(),tensor.permute())这些操作都会使得tensor失去连续性这个性质。我们直接来看看例子吧。
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 | a = torch.arange( 6 ).view( 2 , 3 ) b = a.t() c = a.transpose( 1 , 0 ) d = a.permute( 1 , 0 ) print ( 'b是:' ,b) print ( 'c是:' ,c) print ( 'd是:' ,d)b是: tensor([[ 0 , 3 ], [ 1 , 4 ], [ 2 , 5 ]]) c是: tensor([[ 0 , 3 ], [ 1 , 4 ], [ 2 , 5 ]]) d是: tensor([[ 0 , 3 ], [ 1 , 4 ], [ 2 , 5 ]]) print (a.stride()) print (b.stride()) print (c.stride()) print (d.stride())( 3 , 1 ) ( 1 , 3 ) ( 1 , 3 ) ( 1 , 3 ) |
这里我就不验证她们是不是同一个存储区 (Storage)了,大家下来可以验证下(其实就是同一个)。我们可以看到b,c,d三个tensor的stride()属性和a是不一样的,根据stride()的定义大家应该是很容易知道b,c,d返回的stride()是什么意思吧。那为什么说b,c,d的tensor就不连续了呢?是因为她们不满足张量的连续性条件了。连续性条件如下:
这是什么意思呢?拿b举例就是,b的stride=(3,1),b的size=(3,2),那么stride[0] != stride[1] * size[1]的,因此b是不满足连续性条件的。如果从直观上来感觉来"连续"的意思就是,“我”旁边的数字就应该是“我”真实储存区旁边的数据,例如b[0][0]=0,但是b[0][1]=3,0和3这两个数字在真实的存储区 (Storage)不是挨着的啊,所以叫做不连续。
那不满足连续性有什么后果呢?后果就是不满足连续性的tensor是无法使用view()方法的。换句话说,上面例子中的b,c,d都无法再使用view()方法了。
1 | e = b.view( 1 , 6 )RuntimeError: view size is not compatible with input tensor's size and stride (at least one dimension spans across two contiguous subspaces). Use .reshape(...) instead. |
c,d大家自己下来试一试。所以对于一个tensor是不是连续就意味着他能不能使用view()方法。
那有什么办法让b使用view()方法呢?那就是把b连续化(使用tensor.contiguous()方法)。上例子。
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 | a = torch.arange( 6 ).view( 2 , 3 ) b = a.t() c = b.contiguous() print (a.storage()) print (c.storage()) 0 1 2 3 4 5 [torch.LongStorage of size 6 ] 0 3 1 4 2 5 [torch.LongStorage of size 6 ] print (a.storage().data_ptr()) print (c.storage().data_ptr()) 1881582182144 1881582172928 |
其实tensor.contiguous()方法是创造了一个新的tensor(全新的,连存储区都不共用的tensor),这里的c就是从b得到的连续的tensor了,大家可以打印下c.stride(),会得到(2,1),这样再根据c的size就能发现,c是满足上面提到的连续性公式的。
了解以上知识之后,reshape()和view()的差别就来了,view()是没法对非连续性的tensor使用的(会报错),但是reshape()是可以对非连续性tensor使用的。换句话说
- 当tensor满足连续性要求时,reshape() = view(),和原来tensor共用内存
- 当tensor不满足连续性要求时,reshape() = contiguous() + view(),会产生新的存储区的tensor,与原来tensor不共用内存
这就是view()和reshape()的差别了。
reszie_()
那这一个又和前面那俩有啥关系的呢?从官方文档上来说,它是不希望我们使用这个resize_()的,如图。
前面说到的reshape和view都必须要用到全部的原始数据,比如你的原始数据只有12个,无论你怎么变形都必须要用到12个数字们,不能多不能少。因此你就不能把只有12个数字的tensor强行reshape成2*5的维度的tensor。但是resize_()可以做到,无论你存储区原始有多少个数字,我都能变成你想要的维度,数字不够怎么办?随机产生凑!数字多了怎么办?就取我需要的部分!上例子。
多说一句a.resize_()是会改变a的哟,换句话说,a.resize_(2,3)之后,a就不再是1*7的维度了,而是2*3的维度了。但是a的储存区还是原来的储存区
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 | a = torch.arange( 7 ) print ( "变换前a的储存区地址:" ,a.storage().data_ptr()) b = a.resize_( 2 , 3 ) print ( '这是新的a:' ,a)变换前a的储存区地址: 1881579251648 这是新的a: tensor([[ 0 , 1 , 2 ], [ 3 , 4 , 5 ]]) print (a.storage()) print (b.storage()) 0 1 2 3 4 5 6 [torch.LongStorage of size 7 ] 0 1 2 3 4 5 6 [torch.LongStorage of size 7 ] print ( '变换后a的储存区地址' ,a.storage().data_ptr()) print (b.storage().data_ptr())变换后a的储存区地址 1881579251648 1881579251648 |
你会发现尽管a的”长相“(数字个数也从7个变成了6个)被改变了,但是存储区依旧是没变的(要注意到真实存储区的个数也没变哟还是7个),因此我们可以说resize_()再进行变换的时候如果数字多余了,会截取我们需要的数据量,多余的数据量并没有被舍弃。
再来看看,当我reszie_多于原来的数据的时候发生什么。
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 | a = torch.arange( 7 ) print ( "变换前a的储存区地址:" ,a.storage().data_ptr()) b = a.resize_( 3 , 4 ) print (a.storage()) print (b.storage())变换前a的储存区地址: 1881579250944 0 1 2 3 4 5 6 7667809 6815836 [torch.LongStorage of size 9 ] 0 1 2 3 4 5 6 7667809 6815836 [torch.LongStorage of size 9 ] print ( '变换后a的储存区地址' ,a.storage().data_ptr()) print (b.storage().data_ptr())变换后a的储存区地址 1881582026048 1881582026048 |
这个时候resize_()前后a的储存区地址是发生了变化的哟。
下一个问题:resize_()可不可以对不连续的tensor使用呢?
答案是可以,并且并不会改变原来tensor的内存。当tensor是不连续的时候,采用reshape()会生成个新的存储区的,采用resize_()则不会改变存储区。那这两者有啥区别呢?其实很好解释,reshape是尊重tensor,把存储区改了来将就tensor的reshape的长相,并使得连续。而resize_是:不改存储区,但是“用户”又想要看到想看到的长相,行,那我就把存储区的数按照你想看到的长相排列吧。直接上例子。
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 | import torch a = torch.arange( 6 ).view( 2 , 3 ) b = a.t() #b是这个样子的:tensor([[0, 3], # [1, 4], # [2, 5]]) c = b.reshape( 1 , 6 ) e = b.resize_( 1 , 6 ) print ( "c的存储区:" ,c.storage().data_ptr()) print ( 'e的存储区:' ,e.storage().data_ptr())c的存储区: 2237602017664 e的存储区: 2237602025472print ( "c的存储区真实数据排布:" ,c.storage()) print ( "e的存储区真实数据排布:" ,e.storage())c的存储区真实数据排布: 0 3 1 4 2 5 [torch.LongStorage of size 6 ] e的存储区真实数据排布: 0 1 2 3 4 5 [torch.LongStorage of size 6 ] print ( '我是c:' ,c) print ( '我是e:' ,e)我是c: tensor([[ 0 , 3 , 1 , 4 , 2 , 5 ]]) 我是e: tensor([[ 0 , 1 , 2 , 3 , 4 , 5 ]]) |
可以很直观的看出来,如果tensor是不连续的时候,reshape和resize_的差别了吧。
四、总结
最后总结一下view()、reshape()、reszie_()三者的关系和区别。
- view()只能对满足连续性要求的tensor使用。
- 当tensor满足连续性要求时,reshape() = view(),和原来tensor共用内存。
- 当tensor不满足连续性要求时,reshape() = contiguous() + view(),会产生新的存储区的tensor,与原来tensor不共用内存。
- resize_()可以随意的获取任意维度的tensor,不用在意真实数据的个数限制,但是不推荐使用。
参考博客:
PyTorch:view() 与 reshape() 区别详解_Flag_ing的博客-CSDN博客_reshape和view
pytorch笔记(一)——tensor的storage()、stride()、storage_offset()_Zoran的博客-CSDN博客_pytorch stride()
【推荐】编程新体验,更懂你的AI,立即体验豆包MarsCode编程助手
【推荐】凌霞软件回馈社区,博客园 & 1Panel & Halo 联合会员上线
【推荐】抖音旗下AI助手豆包,你的智能百科全书,全免费不限次数
【推荐】博客园社区专享云产品让利特惠,阿里云新客6.5折上折
【推荐】轻量又高性能的 SSH 工具 IShell:AI 加持,快人一步
· 浏览器原生「磁吸」效果!Anchor Positioning 锚点定位神器解析
· 没有源码,如何修改代码逻辑?
· 一个奇形怪状的面试题:Bean中的CHM要不要加volatile?
· [.NET]调用本地 Deepseek 模型
· 一个费力不讨好的项目,让我损失了近一半的绩效!
· PowerShell开发游戏 · 打蜜蜂
· 在鹅厂做java开发是什么体验
· 百万级群聊的设计实践
· WPF到Web的无缝过渡:英雄联盟客户端的OpenSilver迁移实战
· 永远不要相信用户的输入:从 SQL 注入攻防看输入验证的重要性