tensor.clone()和tensor.detach()

Posted on 2021-09-10 17:08  foghorn  阅读(599)  评论(0编辑  收藏  举报

1 tensor.clone()

返回原tensor的拷贝,返回的新的tensor和原来的tensor具有同样的大小和数据类型
情况一:
若原tensor的requires_grad = True,clone()返回的是中间节点,梯度会流向原tensor,即返回的tensor的梯度会叠加到原来的tensor上。

>>>import torch
>>>a = torch.tensor(1.0, requires_grad=True)
>>>b = a.clone()
>>>id(a), id(b)
(2892334894104, 2892334859464)  # 表明a和b不是同一个对象
>>>a.requires_grad, b.requires_grad
(True, True)  # 两者的requires_grad都是True
>>>c = a * 2
>>>c.backward()
>>>a.grad
tensor(2.)
>>>d = b * 3
>>>d.backward()
>>>b.grad  # b的梯度值为None
>>>a.grad
tensor(5.)  # b的梯度叠加在a上

情况二:
原tensor的requires_grad = False

>>>import torch
>>>a = torch.tensor(1.0, requires_grad=False)
>>>b = a.clone()
>>>id(a), id(b)
(2892334894104, 2892334859464)  # 表明a和b不是同一个对象
>>>a.requires_grad, b.requires_grad
(False, False)  # 两者的requires_grad都是True
>>>b.requires_grad_()
>>>c = b * 2
>>>c.backward()
>>>b.grad
tensor(2.)
>>>a.grad  # None

tensor.detach()

从计算图中脱离出来。返回一个新的tensor,新的tensor和原来的tensor共享数据内存,但不涉及梯度计算。

tensor.clone().detach 和 tensor.detach.clone()

两者结果是一样的。

Copyright © 2024 foghorn
Powered by .NET 9.0 on Kubernetes