Back Propagation - Python实现

  • 算法特征
    ①. 统一看待线性运算与非线性运算; ②. 确定求导变量loss影响链路; ③. loss影响链路梯度逐级反向传播.
  • 算法推导
    Part Ⅰ
    以如下简单正向传播链为例, 引入线性运算与非线性运算符号,
    相关运算流程如下,
    $$
    \begin{equation*}
    \begin{split}
    &\text{linear operation } & I^{(l+1)} = W^{(l)}\cdot O^{(l)} + b^{(l)} \\
    &\text{non-linear operation }\quad & O^{(l+1)} = f^{(l+1)}(I^{(l+1)})
    \end{split}
    \end{equation*}
    $$
    其中, $O^{(l)}$、$I^{(l+1)}$、$O^{(l+1)}$分别为第$l$层输出(output)、第$l+1$层输入(input)、第$l+1$层输出, $W^{(l)}$、$b^{(l)}$、$f^{(l+1)}$分别为相关weight、bias及activation function. 对于线性运算, bias可合并至weight, 则
    $$
    \begin{equation*}
    \text{linear operation }\quad I^{(l+1)} = \tilde{W}^{(l)}\cdot \tilde{O}^{(l)} = g^{(l)}(\tilde{O}^{(l)})
    \end{equation*}
    $$
    其中, $\tilde{W}^{(l)} = [W^{(l)\mathrm{T}}, b^{(l)}]^\mathrm{T}$, $\tilde{O}^{(l)}=[O^{(l)\mathrm{T}},1]^\mathrm{T}$.
    对于上述简单正向传播链, 基础影响链路如下,
    $$
    \begin{equation*}
    \begin{split}
    &\tilde{O}^{(l)} &\quad\rightarrow\quad I^{(l+1)} \\
    &\tilde{W}^{(l)} &\quad\rightarrow\quad I^{(l+1)} \\
    &I^{(l+1)} &\quad\rightarrow\quad O^{(l+1)}
    \end{split}
    \end{equation*}
    $$
    根据链式法则, 影响链路起点由求导变量决定, 终点位于Loss处.
    Part Ⅱ
    现以如下神经网络为例, 加以阐述,

    其中, $(x_1, x_2)$为网络输入, $(y_1, y_2)$为网络输出. 统一处理其中线性运算与非线性运算, 并将该网络完全展开,

    现以$w_1^{(0)}$为例, 作为求导变量确定loss影响链路. $w_1^{(0)}$之loss影响链路(红色+绿色箭头)如下,

     

    拆解为基础影响链路, 如下,
    $$
    \begin{equation*}
    \begin{split}
    & w_1^{(0)} &\quad\rightarrow\quad I_1^{(1)} \\
    & I_1^{(1)} &\quad\rightarrow\quad O_1^{(1)} \\
    & O_1^{(1)} &\quad\rightarrow\quad I_1^{(2)} \\
    & O_1^{(1)} &\quad\rightarrow\quad I_2^{(2)} \\
    & I_1^{(2)} &\quad\rightarrow\quad y_1 \\
    & I_2^{(2)} &\quad\rightarrow\quad y_2 \\
    & y_1 &\quad\rightarrow\quad L \\
    & y_2 &\quad\rightarrow\quad L
    \end{split}
    \end{equation*}
    $$
    定义源(source)为存在多个下游分支的节点(如: $O_1^{(1)}$), 定义汇(sink)为存在多个上游分支的节点(如: $L$). 根据链式求导法则, 在影响链路上, 下游变量对源变量求导需要在源变量处求和, 汇变量对上游变量求导无需特殊处理, 即,
    $$
    \begin{equation*}
    \begin{split}
    &\text{source variable $x$: }\quad &\frac{\partial J(u(x), v(x))}{\partial x} &= \frac{\partial J(u(x), v(x))}{\partial u}\cdot\frac{\partial u}{\partial x} + \frac{\partial J(u(x), v(x))}{\partial v}\cdot\frac{\partial v}{\partial x} \\
    &\text{sink variable $J$: }\quad &\frac{\partial J(u(x), v(y))}{\partial x} &= \frac{\partial J(u(x), v(y))}{\partial u}\cdot\frac{\partial u}{\partial x}
    \end{split}
    \end{equation*}
    $$
    由此, 根据梯度沿影响链路的反向传播, 可确定loss对$w_1^{(0)}$之偏导,
    $$
    \begin{equation*}
    \frac{\partial L}{\partial w_1^{(0)}} = \left(\frac{\partial L}{\partial y_1}\cdot\frac{\partial y_1}{\partial I_1^{(2)}}\cdot\frac{\partial I_1^{(2)}}{\partial O_1^{(1)}} + \frac{\partial L}{\partial y_2}\cdot\frac{\partial y_2}{\partial I_2^{(2)}}\cdot\frac{\partial I_2^{(2)}}{\partial O_1^{(1)}}\right)\cdot\frac{\partial O_1^{(1)}}{\partial I_1^{(1)}}\cdot\frac{\partial I_1^{(1)}}{\partial w_1^{(0)}}
    \end{equation*}
    $$
    再以$b_2^{(1)}$为例, 作为求导变量确定loss影响链路. $b_2^{(1)}$之loss影响链路(红色+绿色箭头)如下,

     

    拆解为基础影响链路, 如下,
    $$
    \begin{equation*}
    \begin{split}
    & b_2^{(1)} &\quad\rightarrow\quad I_2^{(2)} \\
    & I_2^{(2)} &\quad\rightarrow\quad y_2 \\
    & y_2 &\quad\rightarrow\quad L
    \end{split}
    \end{equation*}
    $$
    同样, 根据梯度沿影响链的反向传播, 可确定loss对$b_2^{(1)}$之偏导,
    $$
    \begin{equation*}
    \frac{\partial L}{\partial b_2^{(1)}} = \frac{\partial L}{\partial y_2}\cdot\frac{\partial y_2}{\partial I_2^{(2)}}\cdot\frac{\partial I_2^{(2)}}{\partial b_2^{(1)}}
    \end{equation*}
    $$
  • 代码实现
    现以如下简单feed-forward网络为例进行算法实施,
    输入层为$(r, g, b)$, 输出层为$(x,y,lv)$且不取激活函数, 中间隐藏层取激活函数为双曲正切函数$\tanh$. 采用如下损失函数,
    $$
    \begin{equation*}
    L = \sum_i\frac{1}{2}(\bar{x}^{(i)}-x^{(i)})^2 + \frac{1}{2}(\bar{y}^{(i)} - y^{(i)})^2 + \frac{1}{2}(\bar{lv}^{(i)} - lv^{(i)})^2
    \end{equation*}
    $$
    其中, $i$为data序号, $(\bar{x}, \bar{y}, \bar{lv})$为相应观测值. 相关training data采用如下策略生成,
    $$
    \begin{equation*}
    \left\{
    \begin{split}
    x &= r + 2g + 3b \\
    y &= r^2 + 2g^2 + 3b^2 \\
    lv &= -3r - 4g - 5b
    \end{split}
    \right.
    \end{equation*}
    $$
    具体实现如下,
      1 # Back Propagation之实现
      2 # 优化器使用Adam
      3 
      4 import numpy
      5 from matplotlib import pyplot as plt
      6 
      7 
      8 numpy.random.seed(1)
      9 
     10 
     11 # 生成training data
     12 def getData(n=100):
     13     rgbRange = (-1, 1)
     14     r = numpy.random.uniform(*rgbRange, (n, 1))
     15     g = numpy.random.uniform(*rgbRange, (n, 1))
     16     b = numpy.random.uniform(*rgbRange, (n, 1))
     17     x_ = r + 2 * g + 3 * b
     18     y_ = r ** 2 + 2 * g ** 2 + 3 * b ** 2
     19     lv_ = -3 * r - 4 * g - 5 * b
     20     RGB = numpy.hstack((r, g, b))
     21     XYLv_ = numpy.hstack((x_, y_, lv_))
     22     return RGB, XYLv_
     23     
     24     
     25 class BPEx(object):
     26 
     27     def __init__(self, RGB, XYLv_):
     28         self.__RGB = RGB
     29         self.__XYLv_ = XYLv_
     30         
     31         self.__rgb = None           # (1, 3)
     32         self.__O0 = None            # (1, 4)
     33         self.__W0 = None            # (4, 5)
     34         self.__I1 = None            # (1, 5)
     35         self.__O1 = None            # (1, 6)
     36         self.__W1 = None            # (6, 3)
     37         self.__I2 = None            # (1, 3)
     38         self.__xylv_ = None         # (1, 3)
     39         self.__L = None             # scalar
     40         
     41         self.__grad_W0_I1 = None    # 基础影响链路(W0->I1)之梯度
     42         self.__grad_I1_O1 = None    # 基础影响链路(I1->O1)之梯度
     43         self.__grad_O1_I2 = None    # 基础影响链路(O1->I2)之梯度
     44         self.__grad_W1_I2 = None    # 基础影响链路(W1->I2)之梯度
     45         self.__grad_I2_L = None     # 基础影响链路(I2->Loss)之梯度
     46         
     47         self.__gradList_W0_I1 = list()
     48         self.__gradList_I1_O1 = list()
     49         self.__gradList_O1_I2 = list()
     50         self.__gradList_W1_I2 = list()
     51         self.__gradList_I2_L = list()
     52         
     53         self.__init_weights()
     54         self.__bpTag = False        # 是否反向传播之标志
     55         
     56         
     57     def calc_xylv(self, rgb, W0=None, W1=None):
     58         '''
     59         默认在当前W0、W1之基础上进行预测, 实际使用需要配合优化器
     60         '''
     61         if W0 is not None:
     62             self.__W0 = W0
     63         if W1 is not None:
     64             self.__W1 = W1
     65         
     66         self.__calc_xylv(numpy.array(rgb).reshape((1, 3)))
     67         xylv = self.__I2
     68         return xylv
     69         
     70         
     71     def get_W0W1(self):
     72         return self.__W0, self.__W1
     73         
     74         
     75     def calc_JVal(self, W):
     76         self.__bpTag = True
     77         self.__clr_gradList()
     78     
     79         self.__W0 = W[:20, 0].reshape((4, 5))
     80         self.__W1 = W[20:, 0].reshape((6, 3))
     81         
     82         JVal = 0
     83         for rgb, xylv_ in zip(self.__RGB, self.__XYLv_):
     84             self.__calc_xylv(rgb.reshape((1, 3)))
     85             self.__calc_loss(xylv_.reshape((1, 3)))
     86             JVal += self.__L
     87             self.__add_gradList()
     88             
     89         self.__bpTag = False
     90         return JVal
     91         
     92         
     93     def calc_grad(self, W):
     94         '''
     95         此处W仅为统一调用接口
     96         '''
     97         grad_W0 = numpy.zeros(self.__W0.shape)
     98         grad_W1 = numpy.zeros(self.__W1.shape)
     99         for grad_W0_I1, grad_I1_O1, grad_O1_I2, grad_W1_I2, grad_I2_L in zip(self.__gradList_W0_I1,\
    100             self.__gradList_I1_O1, self.__gradList_O1_I2, self.__gradList_W1_I2, self.__gradList_I2_L):
    101             grad_W1_curr = grad_I2_L * grad_W1_I2
    102             grad_W1 += grad_W1_curr
    103             
    104             term0 = numpy.sum(grad_I2_L.T * grad_O1_I2, axis=0)   # (1, 5) source
    105             term1 = term0 * grad_I1_O1
    106             grad_W0_curr = term1 * grad_W0_I1
    107             grad_W0 += grad_W0_curr
    108             
    109         grad = numpy.vstack((grad_W0.reshape((-1, 1)), grad_W1.reshape((-1, 1))))
    110         return grad
    111         
    112         
    113     def init_seed(self):
    114         n = self.__W0.size + self.__W1.size
    115         seed = numpy.random.uniform(-1, 1, (n, 1))
    116         return seed
    117         
    118         
    119     def __add_gradList(self):
    120         self.__gradList_W0_I1.append(self.__grad_W0_I1)
    121         self.__gradList_I1_O1.append(self.__grad_I1_O1)
    122         self.__gradList_O1_I2.append(self.__grad_O1_I2)
    123         self.__gradList_W1_I2.append(self.__grad_W1_I2)
    124         self.__gradList_I2_L.append(self.__grad_I2_L)
    125         
    126         
    127     def __clr_gradList(self):
    128         self.__gradList_W0_I1.clear()
    129         self.__gradList_I1_O1.clear()
    130         self.__gradList_O1_I2.clear()
    131         self.__gradList_W1_I2.clear()
    132         self.__gradList_I2_L.clear()
    133         
    134         
    135     def __init_weights(self):
    136         self.__W0 = numpy.zeros((4, 5))
    137         self.__W1 = numpy.zeros((6, 3))
    138         
    139         
    140     def __update_grad_by_calc_xylv(self):
    141         self.__grad_W0_I1 = numpy.tile(self.__O0.T, (1, 5))
    142         self.__grad_I1_O1 = 1 - self.__I1_tanh ** 2
    143         self.__grad_O1_I2 = self.__W1.T[:, :-1]
    144         self.__grad_W1_I2 = numpy.tile(self.__O1.T, (1, 3))
    145         
    146         
    147     def __calc_xylv(self, rgb):
    148         '''
    149         rgb: shape=(1, 3)之numpy array
    150         '''
    151         self.__rgb = rgb
    152         self.__O0 = numpy.hstack((self.__rgb, [[1]]))
    153         self.__I1 = numpy.matmul(self.__O0, self.__W0)
    154         self.__I1_tanh = numpy.tanh(self.__I1)
    155         self.__O1 = numpy.hstack((self.__I1_tanh, [[1]]))
    156         self.__I2 = numpy.matmul(self.__O1, self.__W1)
    157         
    158         if self.__bpTag:
    159             self.__update_grad_by_calc_xylv()
    160             
    161             
    162     def __update_grad_by_calc_loss(self):
    163         self.__grad_I2_L = self.__I2 - self.__xylv_
    164             
    165             
    166     def __calc_loss(self, xylv_):
    167         '''
    168         xylv_: shape=(1, 3)之numpy array
    169         '''
    170         self.__xylv_ = xylv_
    171         self.__L = numpy.sum((self.__xylv_ - self.__I2) ** 2) / 2
    172         
    173         if self.__bpTag:
    174             self.__update_grad_by_calc_loss()
    175 
    176 
    177 class Adam(object):
    178 
    179     def __init__(self, _func, _grad, _seed):
    180         '''
    181         _func: 待优化目标函数
    182         _grad: 待优化目标函数之梯度
    183         _seed: 迭代起始点
    184         '''
    185         self.__func = _func
    186         self.__grad = _grad
    187         self.__seed = _seed
    188 
    189         self.__xPath = list()
    190         self.__JPath = list()
    191 
    192 
    193     def get_solu(self, alpha=0.001, beta1=0.9, beta2=0.999, epsilon=1.e-8, zeta=1.e-6, maxIter=3000000):
    194         '''
    195         获取数值解,
    196         alpha: 步长参数
    197         beta1: 一阶矩指数衰减率
    198         beta2: 二阶矩指数衰减率
    199         epsilon: 足够小正数
    200         zeta: 收敛判据
    201         maxIter: 最大迭代次数
    202         '''
    203         self.__init_path()
    204 
    205         x = self.__init_x()
    206         JVal = self.__calc_JVal(x)
    207         self.__add_path(x, JVal)
    208         grad = self.__calc_grad(x)
    209         m, v = numpy.zeros(x.shape), numpy.zeros(x.shape)
    210         for k in range(1, maxIter + 1):
    211             print("iterCnt: {:3d},   JVal: {}".format(k, JVal))
    212             if self.__converged1(grad, zeta):
    213                 self.__print_MSG(x, JVal, k)
    214                 return x, JVal, True
    215 
    216             m = beta1 * m + (1 - beta1) * grad
    217             v = beta2 * v + (1 - beta2) * grad * grad
    218             m_ = m / (1 - beta1 ** k)
    219             v_ = v / (1 - beta2 ** k)
    220 
    221             alpha_ = alpha / (numpy.sqrt(v_) + epsilon)
    222             d = -m_
    223             xNew = x + alpha_ * d
    224             JNew = self.__calc_JVal(xNew)
    225             self.__add_path(xNew, JNew)
    226             if self.__converged2(xNew - x, JNew - JVal, zeta ** 2):
    227                 self.__print_MSG(xNew, JNew, k + 1)
    228                 return xNew, JNew, True
    229 
    230             gNew = self.__calc_grad(xNew)
    231             x, JVal, grad = xNew, JNew, gNew
    232         else:
    233             if self.__converged1(grad, zeta):
    234                 self.__print_MSG(x, JVal, maxIter)
    235                 return x, JVal, True
    236 
    237         print("Adam not converged after {} steps!".format(maxIter))
    238         return x, JVal, False
    239 
    240 
    241     def get_path(self):
    242         return self.__xPath, self.__JPath
    243 
    244 
    245     def __converged1(self, grad, epsilon):
    246         if numpy.linalg.norm(grad, ord=numpy.inf) < epsilon:
    247             return True
    248         return False
    249 
    250 
    251     def __converged2(self, xDelta, JDelta, epsilon):
    252         val1 = numpy.linalg.norm(xDelta, ord=numpy.inf)
    253         val2 = numpy.abs(JDelta)
    254         if val1 < epsilon or val2 < epsilon:
    255             return True
    256         return False
    257 
    258 
    259     def __print_MSG(self, x, JVal, iterCnt):
    260         print("Iteration steps: {}".format(iterCnt))
    261         print("Solution:\n{}".format(x.flatten()))
    262         print("JVal: {}".format(JVal))
    263 
    264 
    265     def __calc_JVal(self, x):
    266         return self.__func(x)
    267 
    268 
    269     def __calc_grad(self, x):
    270         return self.__grad(x)
    271 
    272 
    273     def __init_x(self):
    274         return self.__seed
    275 
    276 
    277     def __init_path(self):
    278         self.__xPath.clear()
    279         self.__JPath.clear()
    280 
    281 
    282     def __add_path(self, x, JVal):
    283         self.__xPath.append(x)
    284         self.__JPath.append(JVal)
    285         
    286         
    287 class BPExPlot(object):
    288 
    289     @staticmethod
    290     def plot_fig(adamObj):
    291         alpha = 0.001
    292         epoch = 50000
    293         x, JVal, tab = adamObj.get_solu(alpha=alpha, maxIter=epoch)
    294         xPath, JPath = adamObj.get_path()
    295         
    296         fig = plt.figure(figsize=(6, 4))
    297         ax1 = fig.add_subplot(1, 1, 1)
    298         
    299         ax1.plot(numpy.arange(len(JPath)), JPath, "k.", markersize=1)
    300         ax1.plot(0, JPath[0], "go", label="seed")
    301         ax1.plot(len(JPath)-1, JPath[-1], "r*", label="solution")
    302         
    303         ax1.legend()
    304         ax1.set(xlabel="$epoch$", ylabel="$JVal$", title="solution-JVal = {:.5f}".format(JPath[-1]))
    305         
    306         fig.tight_layout()
    307         # plt.show()
    308         fig.savefig("plot_fig.png", dpi=100)
    309         
    310         
    311         
    312 if __name__ == "__main__":
    313     RGB, XYLv_ = getData(1000)
    314     bpObj = BPEx(RGB, XYLv_)
    315     
    316     # rgb = (0.5, 0.6, 0.7)
    317     # xylv = bpObj.calc_xylv(rgb)
    318     # print(rgb)
    319     # print(xylv)
    320     # func = bpObj.calc_JVal
    321     # grad = bpObj.calc_grad
    322     # seed = bpObj.init_seed()
    323     # adamObj = Adam(func, grad, seed)
    324     # alpha = 0.1
    325     # epoch = 1000
    326     # adamObj.get_solu(alpha=alpha, maxIter=epoch)
    327     # xylv = bpObj.calc_xylv(rgb)
    328     # print(rgb)
    329     # print(xylv)
    330     # W0, W1 = bpObj.get_W0W1()
    331     # xylv = bpObj.calc_xylv(rgb, W0, W1)
    332     # print(rgb)
    333     # print(xylv)
    334     
    335     func = bpObj.calc_JVal
    336     grad = bpObj.calc_grad
    337     seed = bpObj.init_seed()
    338     adamObj = Adam(func, grad, seed)
    339     BPExPlot.plot_fig(adamObj)
    View Code
  • 结果展示

    可以看到, 在training data上总体loss随epoch增加逐渐降低.

  • 使用建议
    ①. 某层之output节点可能受多个该层之input节点影响(如: softmax激活函数), 此时input节点具备source特性;
    ②. 正向计算过程可确定基础影响链路之梯度, 反向传播过程可串联基础影响链路之梯度;
    ③. 初值的选取对非凸问题的优化比较重要, 权重初值尽量不取全0.
  • 参考文档
    ①. Rumelhart D E, Hinton G E, Williams R J. Learning internal representations by error propagation[R]. California Univ San Diego La Jolla Inst for Cognitive Science, 1985.
    ②. Adam (1) - Python实现
posted @ 2021-10-07 13:04  LOGAN_XIONG  阅读(135)  评论(0编辑  收藏  举报