pytorch 深度学习之自动微分

自动微分

深度学习框架通过自动计算导数,即自动微分(automatic differentiation)来加快求导。 实际中,根据我们设计的模型,系统会构建一个计算图(computational graph), 来跟踪计算是哪些数据通过哪些操作组合起来产生输出。 自动微分使系统能够随后反向传播梯度。 这里,反向传播(backpropagate)意味着跟踪整个计算图,填充关于每个参数的偏导数。

例如,相对 y=2xTx 关于向量 x 求导,首先创建一个变量 x 并为其分配一个初始值:

import torch x = torch.arange(4.0) x
tensor([0., 1., 2., 3.])

需要一个地方来存储梯度。重要的是,我们不会在每次对一个参数求导时都分配新的内存。 因为我们经常会成千上万次地更新相同的参数,每次都分配新的内存可能很快就会将内存耗尽。 注意,一个标量函数关于向量 x 的梯度是向量,并且与 x 具有相同的形状。

x.requires_grad_(True) # 等价于x=torch.arange(4.0,requires_grad=True) x.grad # 默认值是None
y = 2 * torch.dot(x,x) y
tensor(28., grad_fn=<MulBackward0>)

x 是一个长度为 4 的向量,计算 xx 的点积,得到了我们赋值给y的标量输出。 接下来,我们通过调用反向传播函数来自动计算 y 关于 x 每个分量的梯度,并打印这些梯度:

y.backward() x.grad
tensor([ 0., 4., 8., 12.])

函数 y=2xTx 关于 x 的梯度应该是 4x,验证:

x.grad == 4 * x
tensor([True, True, True, True])
# 在默认情况下,PyTorch会累积梯度,我们需要清除之前的值 x.grad.zero_() y = x.sum() y.backward() x.grad
tensor([1., 1., 1., 1.])

非标量变量的反向传播

y 不是标量时,向量 y 关于向量 x 的导数的最自然解释是一个矩阵。对于高阶和高维的 yx,求导的结果可以是一个高阶张量。

# 对非标量调用backward需要传入一个gradient参数,该参数指定微分函数关于self的梯度。 # 在我们的例子中,我们只想求偏导数的和,所以传递一个1的梯度是合适的 x.grad.zero_() y = x * x # 等价于y.backward(torch.ones(len(x))) y.sum().backward() x.grad
tensor([0., 2., 4., 6.])

分离计算

假设 y 是作为 x 的函数计算的,而 z 则是作为 yx 的函数计算的。 想象一下,我们想计算 z 关于 x 的梯度,但由于某种原因,我们希望将 y 视为一个常数, 并且只考虑到 xy 被计算后发挥的作用。
在这里,我们可以分离 y 来返回一个新变量 u,该变量与 y 具有相同的值,但丢弃计算图中如何计算 y 的任何信息。 换句话说,梯度不会向后流经 ux。 因此,下面的反向传播函数计算 z=ux 关于 x 的偏导数,同时将 u 作为常数处理, 而不是 z=xxx 关于 x 的偏导数。

x.grad.zero_() y = x * x u = y.detach(); z = u * x z.sum().backward() x.grad == u
tensor([True, True, True, True])

由于记录了 y 的计算结果,我们可以随后在 y 上调用反向传播, 得到 y=xx 关于的 x 的导数,即 2x

x.grad.zero_() y.sum().backward() x.grad == 2 * x
tensor([True, True, True, True])

Python 控制流的梯度计算

在下面的代码中,while 循环的迭代次数和 if 语句的结果都取决于输入 a 的值:

def f(a): b = a * 2 while b.norm() < 1000: b *= 2 if b.sum() > 0: c = b else: c = 100 * b return c a = torch.randn(size=(),requires_grad=True) d = f(a) d.backward() a.grad == d / a
tensor(True)

__EOF__

本文作者刘皇叔
本文链接https://www.cnblogs.com/xiaojianliu/p/16160210.html
关于博主:评论和私信会在第一时间回复。或者直接私信我。
版权声明:本博客所有文章除特别声明外,均采用 BY-NC-SA 许可协议。转载请注明出处!
声援博主:如果您觉得文章对您有帮助,可以点击文章右下角推荐一下。您的鼓励是博主的最大动力!
posted @   刘-皇叔  阅读(315)  评论(0编辑  收藏  举报
相关博文:
阅读排行:
· 震惊!C++程序真的从main开始吗?99%的程序员都答错了
· 别再用vector<bool>了!Google高级工程师:这可能是STL最大的设计失误
· 单元测试从入门到精通
· 【硬核科普】Trae如何「偷看」你的代码?零基础破解AI编程运行原理
· 上周热点回顾(3.3-3.9)
历史上的今天:
2020-04-18 Sphinx + GitHub + ReadtheDocs 创建电子书
点击右上角即可分享
微信分享提示