在介绍反向传播算法前,先看看矩阵微分的概念。
矩阵微积分
为了书写简便,我们通常把单个函数对多个变量或者多元函数对单个变量的偏导数写成向量和矩阵的形式,使其可以被当成一个整体处理.
标量关于向量的偏导数
对于 M 维向量 x∈RM 和函数 y=f(x)∈R , 则 y 关 于 x 的偏导数为:
∂y∂x=[∂y∂x1,⋯,∂y∂xM]⊤
向量关于标量的偏导数
对于标量 x∈R 和函数 y=f(x)∈RN , 则 y 关于 x 的 偏导数为:
∂y∂x=[∂y1∂x,⋯,∂yN∂x]
向量关于向量的偏导数
对于 M 维向量 x∈RM 和函数 y=f(x)∈RN , 则 f(x) 关于 x 的偏导数为:
∂f(x)∂x=⎡⎢
⎢
⎢
⎢
⎢⎣∂y1∂x1⋯∂yN∂x1⋮⋱⋮∂y1∂xM⋯∂yN∂xM⎤⎥
⎥
⎥
⎥
⎥⎦
前向传播算法
根据之前的介绍,第l层的输出为:
a(l)=fl(W(l)a(l−1)+b(l))
其中:
z(l)=W(l)a(l−1)+b(l)
反向传播算法
假设采用随机梯度下降进行神经网络参数学习, 给定一个样本 (x,y) , 将其输入到神经网络模型中, 得到网络输出为 ^y . 假设损失函数为 L(y,^y) , 要进行参数学习就需要计算损失函数关于每个参数的导数.
不失一般性, 对第 l 层中的参数 W(l) 和 b(l) 计算偏导数. 因为 ∂L(y,^y)∂W(l) 的计算 涉及向量对矩阵的微分, 十分繁琐, 因此我们先计算 L(y,^y) 关于参数矩阵中每个元素的偏导数 ∂L(y,^y)∂w(l)ij . 根据链式法则:
∂L(y,^y)∂w(l)ij=∂z(l)∂w(l)ij∂L(y,^y)∂z(l),∂L(y,^y)∂b(l)=∂z(l)∂b(l)∂L(y,^y)∂z(l)
上面两个公式中的第二项都是目标函数关于第 l 层的神经元 z(l) 的偏导数, 称为误差项,可以一次计算得到.且记δ(l)≜∂L(y,^y)∂z(l).它的大小间接反应了其神经元对整个网络能力的贡献. 这样我们只需要计算三个偏导数, 分别为 ∂z(l)∂w(l)ij,∂z(l)∂b(l) 和 ∂L(y,^y)∂z(l) .
下面分别来计算这三个偏导数
因 z(l)=W(l)a(l−1)+b(l) ,且w(l)ij为l−1层的第j个元素到l层的第i个元素连接的权重,所以:
∂z(l)∂w(l)ij=⎡⎣∂z(l)1∂w(l)ij,⋯,∂z(l)i∂w(l)ij,⋯,∂z(l)Ml∂w(l)ij⎤⎦=⎡⎢⎣0,⋯,∂(w(l)i:a(l−1)+b(l)i)∂w(l)ij,⋯,0⎤⎥⎦=[0,⋯,a(l−1)j,⋯,0]∈R1×Ml
∂z(l)∂b(l)=IMl
其中,IMl是Ml×Ml的单位阵。
因z(l+1)=W(l+1)a(l)+b(l+1),a(l)=fl(z(l))所以:
∂z(l+1)∂a(l)=(W(l+1))⊤
∂a(l)∂z(l)=∂fl(z(l))∂z(l)=diag(f′l(z(l)))
根据链式法则得:
δ(l)≜∂L(y,^y)∂z(l)=∂a(l)∂z(l)⋅∂z(l+1)∂a(l)⋅∂L(y,^y)∂z(l+1)=diag(f′l(z(l)))⋅(W(l+1))⊤⋅δ(l+1)
从上式可以看出,第𝑙 层的误差项可以通过第𝑙 + 1层的误差项计算得到,这就是误差的反向传播。
回到最初
由上面的计算结果可以得到:
∂L(y,^y)∂w(l)ij=[0,⋯,a(l−1)j,⋯,0][δ(l)1,⋯,δ(l)i,⋯,δ(l)Ml]⊤=δ(l)ia(l−1)j
根据下标关系可得:
∂L(y,^y)∂W(l)=δ(l)(a(l−1))⊤∈RMl×Ml−1
等式右边的两个向量尺寸分别为Ml×1和1×Ml−1.
同理:
∂L(y,^y)∂b(l)=δ(l)
计算误差
在计算出每一层的误差项之后,我们就可以得到每一层参数的梯度.因此,使用误差反向传播算法的前馈神经网络训练过程可以分为以下三步:
- 前馈计算每一层的净输入z(l)和激活值a(l),直到最后一层;
- 反向传播计算每一层的误差项δ(l);
- 计算每一层参数的偏导数,并更新参数.
网络训练过程
【推荐】国内首个AI IDE,深度理解中文开发场景,立即下载体验Trae
【推荐】编程新体验,更懂你的AI,立即体验豆包MarsCode编程助手
【推荐】抖音旗下AI助手豆包,你的智能百科全书,全免费不限次数
【推荐】轻量又高性能的 SSH 工具 IShell:AI 加持,快人一步
· 震惊!C++程序真的从main开始吗?99%的程序员都答错了
· 别再用vector<bool>了!Google高级工程师:这可能是STL最大的设计失误
· 单元测试从入门到精通
· 【硬核科普】Trae如何「偷看」你的代码?零基础破解AI编程运行原理
· 上周热点回顾(3.3-3.9)