13-反向传播法求梯度
反向传播法求梯度
利用计算图求梯度是一种比较方便又快速的方法,如何利用计算图求梯度?先回忆一下计算图:
以 z = x 2 + y 2 z=x^2+y^2 z=x2+y2 为例:
-
计算图以箭头和节点构成,正向传播时,得到的结果是 z = x 2 + y 2 z=x^2+y^2 z=x2+y2
-
反向传播时,得到的结果是: ∂ L ∂ z × 2 x \frac{\partial L}{\partial z}\times 2x ∂z∂L×2x 和 ∂ L ∂ z × 2 y \frac{\partial L}{\partial z}\times 2y ∂z∂L×2y
-
仔细一看,令 ∂ L ∂ z = 1 \frac{\partial L}{\partial z}=1 ∂z∂L=1 不就得到梯度为 ( 2 x , 2 y ) (2x,2y) (2x,2y) 了吗!(注意这里的x, y是正向传播的x和y)
为什么令 ∂ L ∂ z = 1 \frac{\partial L}{\partial z}=1 ∂z∂L=1 就可以得到梯度了呢?实际上在计算图中,L才是输出函数,而z只是中间变量,但在这里,z也是输出函数,所以 L=z, 因此有 ∂ L ∂ z = 1 \frac{\partial L}{\partial z}=1 ∂z∂L=1 -
认识到这点,现在就可以写代码实现求梯度了,这里用一个类来实现:
class SqrtWithAdd:
def __init__(self):
self.x = None
self.y = None
def forward(self, x, y):
self.x = x
self.y = y
out = x**2 + y**2
return out
def backward(self, dout=1):
dx = dout * 2*self.x
dy = dout * 2*self.y
return dx,dy #(dx, dy)就是梯度
现在利用这个类来求一下在点 (2, 3) 处的梯度:
sqrt_with_add = SqrtWithAdd() #实例化类
sqrt_with_add.forward(2,3) #先进行正向传播
grad = sqrt_with_add.backward() #求梯度
print(grad)
#输出:
(4, 6)
经手算验证,结果正确。
梯度下降求最小值
既然用反向传播的方法求出了梯度值,那么现在就想用这个方法结合梯度下降法来求一下函数最小值。
先把公式写一下:
x
=
x
−
η
∂
f
∂
x
y
=
y
−
η
∂
f
∂
y
x = x - \eta \frac{\partial f}{\partial x}\\y = y - \eta \frac{\partial f}{\partial y}
x=x−η∂x∂fy=y−η∂y∂f
之前讲过了,
η
\eta
η 是学习率。
代码如下:
def gradient_descent(init_x,init_y, lr=0.01, step_num=100):
x = init_x
y = init_y
sqrt_with_add = SqrtWithAdd() #创建实例
for i in range(step_num):
sqrt_with_add.forward(x,y) #正向传播
dx,dy = sqrt_with_add.backward() #反向传播求梯度
x -= lr * dx
y -= lr * dy
return x,y
#调用函数求最小值位置
x,y = gradient_descent(init_x= 3,init_y= 4) #设置开始起点(3,4)
print('[{}, {}]'.format(x,y))
输出:
[0.39785866768425965, 0.5304782235790126]
结果可以通过调整学习率 lr
和 学习次数 step_num
来接近最佳位置。
这里举的例子是二维的,可以通过调整上面的类的输入值的维数来使其变为可以求 n维的。
最后总结一下反向传播法求梯度与数值微分法求梯度的区别:
反向传播求梯度实际上用的是解析法来求导,跟自己手算求梯度是一样的,记得求导公式就可以;
而数值微分法求的梯度实际上用的是导数定义法来求
本文来自博客园,作者:aJream,转载请记得标明出处:https://www.cnblogs.com/ajream/p/15383572.html
【推荐】国内首个AI IDE,深度理解中文开发场景,立即下载体验Trae
【推荐】编程新体验,更懂你的AI,立即体验豆包MarsCode编程助手
【推荐】抖音旗下AI助手豆包,你的智能百科全书,全免费不限次数
【推荐】轻量又高性能的 SSH 工具 IShell:AI 加持,快人一步
· AI与.NET技术实操系列:向量存储与相似性搜索在 .NET 中的实现
· 基于Microsoft.Extensions.AI核心库实现RAG应用
· Linux系列:如何用heaptrack跟踪.NET程序的非托管内存泄露
· 开发者必知的日志记录最佳实践
· SQL Server 2025 AI相关能力初探
· 震惊!C++程序真的从main开始吗?99%的程序员都答错了
· winform 绘制太阳,地球,月球 运作规律
· 【硬核科普】Trae如何「偷看」你的代码?零基础破解AI编程运行原理
· 上周热点回顾(3.3-3.9)
· 超详细:普通电脑也行Windows部署deepseek R1训练数据并当服务器共享给他人