PRML-5.3.1 误差函数的反向传播,导数的计算
书上的图5.7介绍了神经网络的结构
但是图过于简单,对于推导公式很不利,很难理解,我对原图做了一些修改和扩展,方便大家理解
首先看下图上的一些标记说明
- \(1.共三层神经元,i层(共I个神经元),j层(共J个神经元),k层(共K个神经元),可以理解为i层是输入层,k层是输出层,j层是隐藏层\)
- \(2.所有上标记号指的是哪一层,比如z_1^{j},表示是第j层的第一个神经元\)
- \(3.所谓神经元具体表示的是一个数,比如z_1^{j} 表示的是某个具体的数值\)
- \(4.w代表的是权重,上标指的是哪两层之间的权重,下标指的是对应的这两层的具体哪个编号的神经元,比如w_{12}^{ji}指的是z_2^{{i}}到z_1^{{j}}两个神经元之间的权重\)
- \(5.从一层神经元到下一层神经元,都要经过一次线性变换,公式5.48,j到k层的a的写法已经标示在图上\)
- \(6.经过线性变化a后,还有做一次激活函数h,公式5.49,标示如图\)
开始推导
\(现在的目标时对每个w求出梯度,类似线性回归那样求出梯度,用梯度下降法求解\)
\(假设现在的问题目标是对w_{12}^{kj}求导\)
\(即求\frac{\partial E_n}{w_{12}^{kj}}\)
\(根据公式5.50的链式法则\color{red}{注意,这里用的是线性激活函数,如果是sigmoid激活函数,这里的链式法则要增加对激活函数h的链式求导}\)
$ \frac{\partial E_n}{\partial w_{ji}} = \frac{\partial E_n}{\partial a_j}\frac{\partial a_j}{\partial w_{ji}} \tag{5.50} $
\(\frac{\partial E_n}{w_{12}^{kj}}=\frac{\partial E_n}{\partial a_1^{(k)}}\frac{\partial a_1^{(k)}}{\partial w_{12}^{kj}}\)
\(引入记号(公式5.51)\)
$ \delta_j \equiv \frac{\partial E_n}{\partial a_j} \tag{5.51} $
\(记\delta_1^{k} \equiv \frac{\partial E_n}{\partial a_1^{(k)}},\delta通常被用来标记误差\)
\(然后链式法则后面那一项根据公式\)
$ \frac{\partial a_j}{\partial w_{ji}} = z_i \tag{5.52} $
\(有\frac{\partial a_1^{(k)}}{\partial w_{12}^{kj}}=\frac{\sum\limits_{n=1}^{J}w_{1n}^{kj}z_1^{(j)}}{\partial w_{12}^{kj}}=z_1^{(j)}\)
继续书上的推导,把式(5.51)(5.52)代入式(5.50),可以得到
$ \frac{\partial E_n}{\partial w_{ji}} = \delta_jz_i \tag{5.53} $
\(这里我们得到\)
\(\frac{\partial E_n}{w_{12}^{kj}}=\delta_1^{k}z_2^{(j)}\)
\(这里k层的\delta 值是容易算的,有公式(这个公式推导不展开了,可以看下标准链接函数4.3.6章节,和公式5.18),结论就是只要是指数分布都有这个公式,比如高斯分布-平方和误差函数,二项分布-sigmoid函数,多项式分布-softmax函数\)
$ \delta_k = y_k - t_k \tag{5.54} $
\(走到这里我们已经可以计算所有k层的w的梯度了\)
\(然后往前推,计算j层的w的梯度\)
\(引入链式法则,公式5.50\)
$ \frac{\partial E_n}{\partial w_{ji}} = \frac{\partial E_n}{\partial a_j}\frac{\partial a_j}{\partial w_{ji}} \tag{5.50} $
\(以计算w_{12}^{ji}为例,为\)
\(\frac{\partial E_n}{\partial w_{12}^{ji}} = \frac{\partial E_n}{\partial a_1^{(j)}}\frac{\partial a_1^{(j)}}{\partial w_{12}^{ji}}\)
\(上链式法则的后面容易计算,根据公式 5.52\)
\(\frac{\partial a_1^{(j)}}{\partial w_{12}^{ji}}=z_2^{(i)}\)
\(前一项\)
\(\frac{\partial E_n}{\partial a_1^{(j)}}\)
\(再次使用链式法则,有公式5.55\)
$ \delta_j \equiv \frac{\partial E_n}{\partial a_j} = \sum\limits_k\frac{\partial E_n}{\partial a_k}\frac{\partial a_k}{\partial a_j} \tag{5.55} $
\(也就是误差对j层的梯度,通过链式法则,引入k层\)
\(引入k层这一步的想法也是很自然的,因为神经网络本来就是不断“前馈”产生的,参看公式5.9\)
\(在我们的推导中5.55改写成如下\)
\(\frac{\partial E_n}{\partial a_1^{(j)}} = \sum\limits_{n=1}^{K}\frac{\partial E_n}{\partial a_n^{(k)}}\frac{\partial a_n^{(k)}}{\partial a_1^{(j)}}\)
\(注意一下,这里从误差函数E_n到a_1^{(j)} 中间会经过所有的k层的神经元,所以这里是一个加和操作\)
\(=\frac{\partial E_n}{\partial a_1^{(k)}}\frac{\partial a_1^{(k)}}{\partial a_1^{(j)}}+\frac{\partial E_n}{\partial a_2^{(k)}}\frac{\partial a_2^{(k)}}{\partial a_1^{(j)}}+...+\frac{\partial E_n}{\partial a_K^{(k)}}\frac{\partial a_K^{(k)}}{\partial a_1^{(j)}}\)
因为有5.51式,所有项的左半边记为\(\delta_1^{(k)}\),因为5.48,右半边记为\(\frac{\partial \sum\limits_{m=1}^{J} w_{nm}^{kj}z_m^{(j)}}{\partial a_1^{(j)}}\)
\(=\delta_1^{(k)}\frac{\partial \sum\limits_{m=1}^{J} w_{1m}^{kj}z_m^{(j)}}{\partial a_1^{(j)}}+\delta_2^{(k)}\frac{\partial \sum\limits_{m=1}^{J} w_{2m}^{kj}z_m^{(j)}}{\partial a_1^{(j)}}+...+\delta_K^{(k)}\frac{\partial \sum\limits_{m=1}^{J} w_{Km}^{kj}z_m^{(j)}}{\partial a_1^{(j)}}\)
整理合并
\(=\sum\limits_{n=1}^{K}\delta_n^{(k)} \frac{\partial \sum\limits_{m=1}^{J} w_{nm}^{kj}z_m^{(j)}}{\partial a_1^{(j)}}\)
套用5.49式
\(=\sum\limits_{n=1}^{K}\delta_n^{(k)} \frac{\partial \sum\limits_{m=1}^{J} w_{nm}^{kj}z_m^{(j)}}{\partial a_1^{(j)}}\)
\(根据图中的公式有z_J^{j}=h(a_J^{(j)})\)
\(=\sum\limits_{n=1}^{K}\delta_n^{(k)} \frac{\partial \sum\limits_{m=1}^{J} w_{nm}^{kj}h(a_m^{(j)})}{\partial a_1^{(j)}}\)
\(这里只有m=1的时候偏导有意义\)
\(=\sum\limits_{n=1}^{K}\delta_n^{(k)} \frac{\partial w_{n1}^{kj}h(a_1^{(j)})}{\partial a_1^{(j)}}\)
\(=\sum\limits_{n=1}^{K}\delta_n^{(k)} w_{n1}^{kj}h'(a_1^{(j)})\)
\(=h'(a_1^{(j)})\sum\limits_{n=1}^{K}\delta_n^{(k)} w_{n1}^{kj}\)
得到了公式 5.56
$ \delta_j = h'(a_j)\sum\limits_kw_{kj}\delta_k \tag{5.56} $
总结算法
于是,反向传播算法可以总结如下:
- 对于网络的一个输入向量$ _n$,使用式(5.48)(5.49)进行正向传播,找到所有隐含单元和输出单元的激活。
- 使用式(5.54)计算所有输出单元的\(\delta_k\)。
- 使用式(5.56)反向传播\(\delta\),获得网络中所有隐含单元的\(\delta_j\)。
- 使用式(5.53)计算导数。