ARTS-S pytorch中backward函数的gradient参数作用

导数偏导数的数学定义

参考资料1和2中对导数偏导数的定义都非常明确.导数和偏导数都是函数对自变量而言.从数学定义上讲,求导或者求偏导只有函数对自变量,其余任何情况都是错的.但是很多机器学习的资料和开源库都涉及到标量对向量求导.比如下面这个pytorch的例子.

import torch
x = torch.tensor([1.0, 2.0, 3.0], requires_grad=True)
y = x ** 2 + 2
z = torch.sum(y)
z.backward()
print(x.grad)

简单解释下,设x=[x1,x2,x3],则

z=x12+x22+x32+6

zx1=2x1

zx2=2x2

zx3=2x3

x1=1.0,x2=2.0,x3=3.0代入就可以得到

(zx1,zx2,zx3)=(2x1,2x2,2x3)=(2.0,4.0,6.0)

结果是和pytorch的输出是一样的.反过来想想,其实所谓的"标量对向量求导"本质上是函数对各个自变量求导,这里只是把各个自变量看成一个向量.和数学上的定义并不矛盾.

backward的gradient参数作用

现在有如下问题,已知

y1=x1x2x3

y2=x1+x2+x3

y3=x1+x2x3

A=f(y1,y2,y3)

其中函数f(y1,y2,y3)的具体定义未知,现在求

Ax1=?

Ax2=?

Ax3=?

根据参考资料2中讲的多元复合函数的求导法则.

Ax1=Ay1y1x1+Ay2y2x1+Ay3y3x1

Ax2=Ay1y1x2+Ay2y2x2+Ay3y3x2

Ax3=Ay1y1x3+Ay2y2x3+Ay3y3x3

上面3个等式可以写成矩阵相乘的形式.如下

(1)[Ax1,Ax2,Ax3]=[Ay1,Ay2,Ay3][y1x1y1x2y1x3y2x1y2x2y2x3y3x1y3x2y3x3]

其中

[y1x1y1x2y1x3y2x1y2x2y2x3y3x1y3x2y3x3]

叫作雅可比(Jacobian)式.雅可比式可以根据已知条件求出.现在只要知道[Ay1,Ay2,Ay3]的值,哪怕不知道f(y1,y2,y3)的具体形式也能求出来[Ax1,Ax2,Ax3]. 那现在的现在的问题是:
怎么样才能求出

[Ay1,Ay2,Ay3]

答案是由pytorch的backward函数的gradient参数提供.这就是gradient参数的作用. 参数gradient能解决什么问题,有什么实际的作用呢?说实话,因为我才接触到pytorch,还真没有见过现实中怎么用gradient参数.但是目前可以通过数学意义来理解,就是可以忽略复合函数某个位置之前的所有函数 的具体形式,直接给定一个梯度来求得对各个自变量的偏导.
上面各个方程用代码表示如下所示:

# coding utf-8
import torch

x1 = torch.tensor(1, requires_grad=True, dtype=torch.float)
x2 = torch.tensor(2, requires_grad=True, dtype=torch.float)
x3 = torch.tensor(3, requires_grad=True, dtype=torch.float)
y = torch.randn(3)
y[0] = x1 * x2 * x3
y[1] = x1 + x2 + x3
y[2] = x1 + x2 * x3
x = torch.tensor([x1, x2, x3])
y.backward(torch.tensor([0.1, 0.2, 0.3], dtype=torch.float))
print(x1.grad)
print(x2.grad)
print(x3.grad)

按照上用的推导方法

[Ax1,Ax2,Ax3]=[Ay1,Ay2,Ay3][x2x3x1x3x1x21111x3x2]=[0.1,0.2,0.3][632111132]=[1.1,1.4,1.0]

和代码的运行结果是一样的.

参考资料

  1. 同济大学数学系,高等数学第七版上册,高等教育出版社,p75-76, 2015.
  2. 同济大学数学系,高等数学第七版下册,高等教育出版社,p78-80,p88-91, 2015.
  3. Calculus,Thirteenth Edition,p822, 2013.
  4. 详解Pytorch 自动微分里的(vector-Jacobian product)
  5. PyTorch 的 backward 为什么有一个 grad_variables 参数?)

posted on   荷楠仁  阅读(6082)  评论(5编辑  收藏  举报

编辑推荐:
· Linux系列:如何用 C#调用 C方法造成内存泄露
· AI与.NET技术实操系列(二):开始使用ML.NET
· 记一次.NET内存居高不下排查解决与启示
· 探究高空视频全景AR技术的实现原理
· 理解Rust引用及其生命周期标识(上)
阅读排行:
· DeepSeek 开源周回顾「GitHub 热点速览」
· 物流快递公司核心技术能力-地址解析分单基础技术分享
· .NET 10首个预览版发布:重大改进与新特性概览!
· AI与.NET技术实操系列(二):开始使用ML.NET
· 单线程的Redis速度为什么快?

导航

统计

点击右上角即可分享
微信分享提示