反向传播理论推导

希望本文成为你见过的反向传播理论中最易理解的解释和最简洁形式的公式推导 😃

反向传播是上世纪80年代提出的训练神经网络的一种方法,在每次迭代训练时修改对每个神经元输入的权值,来达到最后一层的输出与期望的输出的总误差最小的目的。反向传播算法可以说是梯度下降在链式法则中的应用。

反向传播与梯度下降

Q:为什么会提出反向传播算法,直接应用梯度下降(Gradient Descent)不行吗?

梯度下降可以应对带有明确求导函数的情况(容易求解析解的情况),比如逻辑回归(Logistic Regression),我们可以把它看做没有隐层的网络;但对于多层神经网络,损失函数对前边的参数计算解析解导数,将十分复杂。而如果要数值解(函数的导数在某个点的数值),就很容易计算,直接根据导数的定义就能得到。

CwjC(w+ϵej)C(w)ϵ

式中 cost function 看作权重的函数 C=C(w), 忽略了偏置bias(其导数计算与w类似)。其中ϵ是一个大于零的很小的正数,ej 表示单位向量(方向坐标)。

通过这个导数的定义来计算梯度,实现比较简单,能够直接应用梯度下降来更新模型参数。然而网络参数量通常都很大,这种求解方式的计算量也会非常大,原因如下:

对于每一个权重wj,计算梯度Cwj时需要计算一次C(w+ϵej),这需要一次完整的前向传播才能得到。假如有一千万个参数,则完整执行一次梯度下降需要一千万+1次前向传播,+1次前向传播是原本计算损失值 C(w)的那次。这样的计算量显然不可接受,而反向传播通过链式求导法则,仅需要反向传播一次,并且反向传播与正向传播的计算量相当(反向时采用权重矩阵转置的乘法)。因此反向传播算法可以说是梯度下降在链式法则中的应用。

Q:RNN为什么要基于时间步骤反向传播,直接梯度下降不行吗?

对于RNN训练也是类似的逻辑,通过反向传播来计算梯度。由于RNN具有循环结构,可看作是展开的权重共享的多层神经网络。按照常规的链式求导思路可推导出基于时间步骤反向传播(BPTT)方法,除了有递归结构之外与常规的反向传播无异。与常规反向传播一样,存在梯度消失或梯度爆炸问题。 为了计算方便性和数值稳定性的需要,常使用截断的方法,包括:规则截断和随机截断。截断导致该模型主要侧重于短期影响,而不是长期影响。

反向传播推导

本文以多层感知机为例对反向传播公式进行推导, 但不局限于某种激活函数或损失函数.

先上精简版的图示, 帮助解释:

BP图示

BP图示

反向传播的目的是更新神经元参数,而神经元参数正是 z=wx+b 中的 (w,b).
对参数的更新为:利用损失值loss对参数求导数, 并沿着负梯度方向进行更新。

w=wηwb=bηb

运用链式法则先求误差对当前神经元的线性加权的结果 z 的梯度,再求对当前神经元输入连接的权重 (w,b) 的梯度.

(1)Ezil=j[Ezjl+1zjl+1ailailzil](2)=j[Ezjl+1wjil+1σ(zil)]

其中σ()为激活函数, σ() 为激活函数的导数.
ajl+1,ail 分别代表当前第l+1层第j个神经元与前一层(第l层)第i个神经元的激活输出值.

(3)ail=σ(zll)(4)zjl+1=wjil+1ail+bjl+1

注意的是式子和图示中的 E 是单个训练样本的损失, 而式子中的求和符号 是对应层的神经元求和. 不要和代价损失的求和混淆了(cost function 是对所有训练样本的损失求平均值).

从上式中,你应该能看出来损失对每层神经元的zi的梯度是个从后往前的递推关系式. 即:

δil=j[δjl+1wjil+1σ(zil)]

(w,b) 梯度

我们想要的(w,b)的梯度也能够立刻计算出来:
l层的第i个神经元与第l+1层的第j个神经元的权重:

(5)Ewjil+1=Ezilzilwjil+1=δjl+1ail(6)Ebjl+1=δjl+1

往前一层(第 i 层)的权重:

(7)Ewihl=δilahl1(8)Ebil=δil

就这样, 梯度从后往前一层一层传播计算, 每层的梯度等于后一层的梯度与当前层的局部偏导数的乘积, 这正是链式法则的简单应用。

矩阵乘

对上面的公式稍加变化便能改成矩阵乘法形式:

(9)δil=j[δjl+1wjil+1σ(zil)](10)[δ]j×1=[(w+1)k×jT×(δ+1)k×1]j×1σ(z)j×1(11)[Ew+1]j×i=[δ+1]j×1[a]1×i

其中 是在反向传播中常用的哈达玛乘积(Hadamard product),对两个相同维度的矩阵按元素相乘得到同维度的输出。

参考

posted @   康行天下  阅读(2883)  评论(0编辑  收藏  举报
编辑推荐:
· 理解Rust引用及其生命周期标识(下)
· 从二进制到误差:逐行拆解C语言浮点运算中的4008175468544之谜
· .NET制作智能桌面机器人:结合BotSharp智能体框架开发语音交互
· 软件产品开发中常见的10个问题及处理方法
· .NET 原生驾驭 AI 新基建实战系列:向量数据库的应用与畅想
阅读排行:
· 2025成都.NET开发者Connect圆满结束
· Ollama本地部署大模型总结
· langchain0.3教程:从0到1打造一个智能聊天机器人
· 在 VS Code 中,一键安装 MCP Server!
· 用一种新的分类方法梳理设计模式的脉络
点击右上角即可分享
微信分享提示