pytorch中copy() clone() detach()
Torch 为了提高速度,向量或是矩阵的赋值是指向同一内存的
如果需要开辟新的存储地址而不是引用,可以用clone()进行深拷贝
区别
clone()
解释说明: 返回一个原张量的副本,同时不破坏计算图,它能够维持反向传播计算梯度,
并且两个张量不共享内存.一个张量上值的改变不影响另一个张量.
copy_()
解释说明: 比如x4.copy_(x2), 将x2的数据复制到x4,并且会
修改计算图,使得反向传播自动计算梯度时,计算出x4的梯度后
再继续前向计算x2的梯度. 注意,复制完成之后,两者的值的改变互不影响,
因为他们并不共享内存.
detach()
解释说明: 比如x4 = x2.detach(),返回一个和原张量x2共享内存的新张量x4,
两者的改动可以相互可见, 一个张量上的值的改动会影响到另一个张量.
返回的新张量和计算图相互独立,即新张量和计算图不再关联,
因此也无法进行反向传播计算梯度.即从计算图上把这个张量x2拆
卸detach下来,非常形象.
detach_()
解释说明: detach()的原地操作版本,功能和detach()类似.
比如x4 = x2.detach_(),其实x2和x4是同一个对象,返回的是self,
x2和x4具有相同的id()值.
copy.copy
例子
a = torch.tensor([[1.,2.,3.],[4.,5.,6.]], requires_grad=True)
print(a)
b=a.detach()
print(b)
"""
tensor([[1., 2., 3.],
[4., 5., 6.]], requires_grad=True)
tensor([[1., 2., 3.],
[4., 5., 6.]])
"""
detach()操作后的tensor与原始tensor共享数据内存,当原始tensor在计算图中数值发生反向传播等更新之后,detach()的tensor值也发生了改变
a = torch.tensor([[1.,2.,3.],[4.,5.,6.]], requires_grad=True)
print(a)
b=a.clone()
print(b)
"""
tensor([[1., 2., 3.],
[4., 5., 6.]], requires_grad=True)
tensor([[1., 2., 3.],
[4., 5., 6.]], grad_fn=<CloneBackward>)
"""
grad_fn=<CloneBackward>表示clone后的返回值是个中间变量,因此支持梯度的回溯。
a = torch.tensor([[1.,2.,3.],[4.,5.,6.]], requires_grad=True)
print(a)
b=a.detach().clone()
print(b)
"""
tensor([[1., 2., 3.],
[4., 5., 6.]], requires_grad=True)
tensor([[1., 2., 3.],
[4., 5., 6.]])
"""
a = torch.tensor([[1.,2.,3.],[4.,5.,6.]], requires_grad=True)
print(a)
b=a.detach().clone().requires_grad_(True)
print(b)
"""
tensor([[1., 2., 3.],
[4., 5., 6.]], requires_grad=True)
tensor([[1., 2., 3.],
[4., 5., 6.]], requires_grad=True)
"""
clone()操作后的tensor requires_grad=True
detach()操作后的tensor requires_grad=False
import torch
torch.manual_seed(0)
x= torch.tensor([1., 2.], requires_grad=True)
clone_x = x.clone()
detach_x = x.detach()
clone_detach_x = x.clone().detach()
f = torch.nn.Linear(2, 1)
y = f(x)
y.backward()
print(x.grad)
print(clone_x.requires_grad)
print(clone_x.grad)
print(detach_x.requires_grad)
print(clone_detach_x.requires_grad)
'''
输出结果如下:
tensor([-0.0053, 0.3793])
True
None
False
False
'''