Pytorch深度学习:自动微分
自动微分
根据设计好的模型,系统会构建一个计算图(computational graph), 来跟踪计算是哪些数据通过哪些操作组合起来产生输出。 自动微分使系统能够随后反向传播梯度。 反向传播(backpropagate)意味着跟踪整个计算图,填充关于每个参数的偏导数。
我们先看一个简单的例子:
函数y是一个标量函数
有一个列向量\(x=(x_1,x_2,x_3,x_4)^T\),函数:
那么我们想求解y关于x的梯度,怎么求呢?
import torch
x = torch.arange(4.0)
x.requires_grad_(True) # 等价于x=torch.arange(4.0,requires_grad=True)
# 此时x.grad默认值是None
y = 2 * torch.dot(x, x)
y.backward() #反向传播,计算y关于x的每个分量的梯度
print(x.grad) #梯度保存在这里,结果为:tensor([ 0., 4., 8., 12.])
如果我们还想计算x的另一个函数,那么先对x的梯度清0:
# 在默认情况下,PyTorch会累积梯度,我们需要清除之前的值
x.grad.zero_()
y = x.sum()
y.backward()
# x.grad为tensor([1., 1., 1., 1.])
函数y是向量时
当函数y不是标量时,向量y
关于向量x
的导数是一个矩阵。 求导的结果是一个高阶张量。
但一般我们调用向量的反向传播时,通常是计算一个批量里每个样本的偏导数之和。
# 对非标量调用backward需要传入一个gradient参数,该参数指定微分函数关于self的梯度。
# 本例只想求偏导数的和,所以传递一个1的梯度是合适的
x.grad.zero_()
y = x * x #y=(x1^2,...,x4^2),是一个向量
y.sum().backward()
x.grad
上面代码里y.sum().backward()
也可以等价地写成:
y.backward(torch.ones(len(x)))
backward
函数接受了一个叫做gradient
的参数,当y是标量时不需要该参数,但如果是向量则必须传入参数gradient
。该参数的作用是:
当\(\textbf{y}\)对\(\textbf{x}\)求导时,结果是一个梯度矩阵,为:
当获取x的梯度时,有:
参考:[backward函数中gradient参数的一些理解](https://www.cnblogs.com/meitiandouyaokaixin/p/16339669.html#:~:text=backward函数中gradient参数的一些理解 当标量对向量求导时不需要该参数,但当向量对向量求导时,若不加上该参数则会报错,显示“grad can be implicitly created only,for scalar outputs”,对该gradient参数解释如下。 当 y 对 x 求导时,结果为梯度矩阵,数学表达如下:)
分离计算
如果我们希望把某些计算移动到计算图之外,那么可以采用分离计算来实现:
x = torch.arange(4.0)
y = x * x # y = (x1^2,...,x4^2) = (1,4,9,16)
u = y.detach() # u = (1,4,9,16)
z = u * x # z = (x1, 4*x2, 9*x3, 16*x4)
z.sum().backward()
# 结果:x.grad 和 u 相等
u就是y移除计算图的变量,只是一个常数张量。
注意,此时u.requires_grad
为False,也可以通过u.requires_grad_(True)
将其设置为叶子节点,不过它就相当于一个刚创建的张量,之前的计算图和它没有关系。
注意事项
- 当尝试输出非叶子节点的梯度时
对于如下代码:
x = torch.arange(4.0, requires_grad=True)
y = x * x
z = y * x
z.sum().backward()
print(y.grad)
# 结果为None
当尝试输出叶子节点y
的梯度值时,会报出warning,警告不要获取非叶子节点的梯度,并且返回None。
- 当尝试两次调用backward函数时:
对于上面的代码,我们连续调用:
z.sum().backward()
z.sum().backward()
会报错,报错信息显示Saved intermediate values of the graph are freed when you call .backward() or autograd.grad().
- 好文要顶
一文解释 PyTorch求导相关 (backward, autograd.grad) - 知乎 (zhihu.com)这篇文章写的太好了。