pytorch requires_grad = True的意思

计算图通常包含两种元素,一个是 tensor,另一个是 Function。张量 tensor 不必多说,但是大家可能对 Function 比较陌生。这里 Function 指的是在计算图中某个节点(node)所进行的运算,比如加减乘除卷积等等之类的,Function 内部有 forward() 和 backward() 两个方法,分别应用于正向、反向传播。

当我们创建一个张量 (tensor) 的时候,如果没有特殊指定的话,那么这个张量是默认是不需要求导的。们在训练一个网络的时候,我们从 DataLoader 中读取出来的一个 mini-batch 的数据,这些输入默认是不需要求导的,其次,网络的输出我们没有特意指明需要求导吧,Ground Truth 我们也没有特意设置需要求导吧。这么一想,哇,那我之前的那些 loss 咋还能自动求导呢?其实原因就是上边那条规则,虽然输入的训练数据是默认不求导的,但是,我们的 model 中的所有参数,它默认是求导的,这么一来,其中只要有一个需要求导,那么输出的网络结果必定也会需要求的。来看个实例:

input = torch.randn(8, 3, 50, 100)
print(input.requires_grad)
# False

net = nn.Sequential(nn.Conv2d(3, 16, 3, 1),
                    nn.Conv2d(16, 32, 3, 1))
for param in net.named_parameters():
    print(param[0], param[1].requires_grad)
# 0.weight True
# 0.bias True
# 1.weight True
# 1.bias True

output = net(input)
print(output.requires_grad)
# True

在写代码的过程中,不要把网络的输入和 Ground Truth 的 requires_grad 设置为 True。虽然这样设置不会影响反向传播,但是需要额外计算网络的输入和 Ground Truth 的导数,增大了计算量和内存占用不说,这些计算出来的导数结果也没啥用。因为我们只需要神经网络中的参数的导数,用来更新网络,其余的导数都不需要。

 

原文链接:https://zhuanlan.zhihu.com/p/67184419

posted on 2022-03-18 17:05  啥123  阅读(1500)  评论(0编辑  收藏  举报