torch.Tensor 的一些注意点

1. 广播机制

2. expand和repeat

3. view和reshape

 

1. 广播机制

广播机制牢记一点就可以了:从末尾开始连续找第一个不相同的维度容量开始扩充,直到扩充成维度相同。

 

比如:

a = torch.rand(3, 4)
b = torch.rand(4)
c1 = a + b #size:[3,4]

a = torch.rand(1, 3)
b = torch.rand(4)
c1 = a + b #size:Runtimee Error

a = torch.rand(1, 3)
b = torch.rand(4, 1)
c1 = a + b #size:[4,3]

2. expand和repeat

 

区别:1)前者expand返回的不是新内存,只是一个视图而已,后者repeat重复是拷贝内存的,会返回一个新的内存地址

 

           2)前者expand括号内的参数是扩展后的维度,后者repeat括号内参数是在对应维度上的重复次数

 

比如:

 

x1 = torch.tensor([1, 2, 3])
x2 = x1.expand(2, 3)

x2[1,1] = 5
print(x2)#[1 5 3; 1 5 3]
print(id(x1.data)) #1794517976552
print(id(x2.data)) #1794517976552


y1 = torch.tensor([1, 2, 3])
y2 = y1.repeat(2, 1)

y2[1,1] = 5
print(y2)#[1 2 3; 1 5 3]
print(id(y1.data)) #1794517592328
print(id(y2.data)) #1794517601160


z = torch.tensor([1, 2, 3])
z = z.repeat(2, 3)

z[1,1] = 5
print(z)#[1, 2, 3, 1, 2, 3, 1, 2, 3;1, 5, 3, 1, 2, 3, 1, 2, 3]

 

3. view和reshape

区别:view要求Tensor的存储区是连续的,如果不是,则会报错。比如对一个tensor先用了permute方法,再接view方法,会产生这样的错误。如果要避免,就需要在view之前用continguous方法。

而reshape相当于continguous+view(不满足连续性)或者直接等于view(满足连续性)。所以view是显然不如reshape的,如果说view有什么优势的话,大概就是用view可以断定直接新生产的Tensor一定不是一个copy,而是一个原Tensor的视图(共享内存)。

比如:

a = torch.arange(9).reshape(3, 3)  # 初始化张量 a
b = a.permute(1,0) # 初始化张量 b
c = a.permute(1,0).contiguous()

print('a:', a)
print('ptr of storage of a:', a.storage().data_ptr())
print('b:', b)
print('ptr of storage of b:', b.storage().data_ptr())#b和a共享存储空间
print('c:', c)
print('ptr of storage of c:', c.storage().data_ptr())#c用了另外的存储空间

#b = b.view(1,9)#直接用会报错,因为此时b已经不连续了
b = b.reshape(1,9)
c = c.view(1,9)

print('a:', a)
print('ptr of storage of a:', a.storage().data_ptr())
print('b:', b)
print('ptr of storage of b:', b.storage().data_ptr())#b用了另外的存储空间,此时reshape=contiguous+view
print('c:', c)
print('ptr of storage of c:', c.storage().data_ptr())#c的存储空间不变

b = a.view(1,9)
c = a.reshape(1,9)
print('a:', a)
print('ptr of storage of a:', a.storage().data_ptr())
print('b:', b)
print('ptr of storage of b:', b.storage().data_ptr())
print('c:', c)
print('ptr of storage of c:', c.storage().data_ptr())#a,b,c的存储空间不变,此时reshape就等于view

posted on 2021-08-11 10:25  博闻强记2010  阅读(246)  评论(0编辑  收藏  举报

导航