Adam (1) - Python实现

  • 算法特征
    ①. 梯度凸组合控制迭代方向; ②. 梯度平方凸组合控制迭代步长; ③. 各优化变量自适应搜索.
  • 算法推导
    Part Ⅰ 算法细节
    拟设目标函数符号为$J$, 则梯度表示如下,
    \begin{equation}
    g = \nabla J
    \label{eq_1}
    \end{equation}
    参考Momentum Gradient, 对梯度凸组合控制迭代方向first momentum,
    \begin{equation}
    m_{k} = \beta_1m_{k-1} + (1 - \beta_1)g_{k}
    \label{eq_2}
    \end{equation}
    其中, $\beta_1$是凸组合系数, 也是指数衰减率.
    参考RMSProp, 对梯度平方凸组合控制迭代步长second raw momentum,
    \begin{equation}
    v_{k} = \beta_2v_{k-1} + (1 - \beta_2)g_{k}\odot g_{k}
    \label{eq_3}
    \end{equation}
    其中, $\beta_2$是凸组合系数, 也是指数衰减率.
    由于first momentum与second raw momentum均初始化为0, 分别以如下方式修正以降低凸组合系数对初始迭代的影响,
    \begin{gather}
    \hat{m}_{k} = \frac{m_{k}}{1 - \beta_1^{k}}\label{eq_4} \\
    \hat{v}_{k} = \frac{v_{k}}{1 - \beta_2^{k}}\label{eq_5}
    \end{gather}
    不失一般性, 令第$k$步迭代形式如下,
    \begin{equation}
    x_{k+1} = x_k + \alpha_kd_k
    \label{eq_6}
    \end{equation}
    其中, $\alpha_k$、$d_k$分别代表第$k$步迭代步长与迭代方向, 且
    \begin{gather}
    \alpha_k = \frac{\alpha}{\sqrt{\hat{v}_k} + \epsilon}\label{eq_7} \\
    d_k = -\hat{m}_k\label{eq_8}
    \end{gather}
    其中, $\alpha$代表步长参数, $\epsilon$取值足够小正数避免迭代步长分母为0.
    Part Ⅱ 算法流程
    初始化步长参数$\alpha$、足够小正数$\epsilon$、指数衰减率$\beta_1$、指数衰减率$\beta_2$
    初始化收敛判据$\zeta$、迭代起点$x_1$
    计算当前梯度值$g_1=\nabla J(x_1)$, 令: 一阶矩$m_0 = 0$, 二阶矩$v_0 = 0$, $k = 1$, 重复以下步骤,
      step1: 如果$\|g_k\| < \zeta$, 收敛, 迭代停止
      step2: 更新一阶矩$m_k = \beta_1m_{k-1} + (1 - \beta_1)g_{k}$
      step3: 更新二阶矩$v_k = \beta_2v_{k-1} + (1 - \beta_2)g_{k}\odot g_{k}$
      step4: 计算一阶矩修正$\displaystyle \hat{m}_{k} = \frac{m_{k}}{1 - \beta_1^{k}}$
      step5: 计算二阶矩修正$\displaystyle \hat{v}_{k} = \frac{v_{k}}{1 - \beta_2^{k}}$
      step6: 计算迭代步长$\displaystyle \alpha_k = \frac{\alpha}{\sqrt{\hat{v}_k} + \epsilon}$
      step7: 计算迭代方向$d_k = -\hat{m}_k$
      step8: 更新迭代点$x_{k+1} = x_k + \alpha_kd_k$
      step9: 更新梯度值$g_{k+1}=\nabla J(x_{k+1})$
      step10: 令$k = k+1$, 转step1
  • 代码实现
    现以如下无约束凸优化问题为例进行算法实施,
    \begin{equation*}
    \min\quad 5x_1^2 + 2x_2^2 + 3x_1 - 10x_2 + 4
    \end{equation*}
    Adam实现如下,
      1 # Adam之实现
      2 
      3 import numpy
      4 from matplotlib import pyplot as plt
      5 
      6 
      7 # 目标函数0阶信息
      8 def func(X):
      9     funcVal = 5 * X[0, 0] ** 2 + 2 * X[1, 0] ** 2 + 3 * X[0, 0] - 10 * X[1, 0] + 4
     10     return funcVal
     11     
     12     
     13 # 目标函数1阶信息
     14 def grad(X):
     15     grad_x1 = 10 * X[0, 0] + 3
     16     grad_x2 = 4 * X[1, 0] - 10
     17     gradVec = numpy.array([[grad_x1], [grad_x2]])
     18     return gradVec
     19     
     20     
     21 # 定义迭代起点
     22 def seed(n=2):
     23     seedVec = numpy.random.uniform(-100, 100, (n, 1))
     24     return seedVec
     25     
     26     
     27 class Adam(object):
     28     
     29     def __init__(self, _func, _grad, _seed):
     30         '''
     31         _func: 待优化目标函数
     32         _grad: 待优化目标函数之梯度
     33         _seed: 迭代起始点
     34         '''
     35         self.__func = _func
     36         self.__grad = _grad
     37         self.__seed = _seed
     38         
     39         self.__xPath = list()
     40         self.__JPath = list()
     41         
     42         
     43     def get_solu(self, alpha=0.001, beta1=0.9, beta2=0.999, epsilon=1.e-8, zeta=1.e-6, maxIter=3000000):
     44         '''
     45         获取数值解,
     46         alpha: 步长参数
     47         beta1: 一阶矩指数衰减率
     48         beta2: 二阶矩指数衰减率
     49         epsilon: 足够小正数
     50         zeta: 收敛判据
     51         maxIter: 最大迭代次数
     52         '''
     53         self.__init_path()
     54         
     55         x = self.__init_x()
     56         JVal = self.__calc_JVal(x)
     57         self.__add_path(x, JVal)
     58         grad = self.__calc_grad(x)
     59         m, v = numpy.zeros(x.shape), numpy.zeros(x.shape)
     60         for k in range(1, maxIter + 1):
     61             # print("k: {:3d},   JVal: {}".format(k, JVal))
     62             if self.__converged1(grad, zeta):
     63                 self.__print_MSG(x, JVal, k)
     64                 return x, JVal, True
     65             
     66             m = beta1 * m + (1 - beta1) * grad
     67             v = beta2 * v + (1 - beta2) * grad * grad
     68             m_ = m / (1 - beta1 ** k)
     69             v_ = v / (1 - beta2 ** k)
     70             
     71             alpha_ = alpha / (numpy.sqrt(v_) + epsilon)
     72             d = -m_
     73             xNew = x + alpha_ * d
     74             JNew = self.__calc_JVal(xNew)
     75             self.__add_path(xNew, JNew)
     76             if self.__converged2(xNew - x, JNew - JVal, zeta ** 2):
     77                 self.__print_MSG(xNew, JNew, k + 1)
     78                 return xNew, JNew, True
     79                 
     80             gNew = self.__calc_grad(xNew)
     81             x, JVal, grad = xNew, JNew, gNew
     82         else:
     83             if self.__converged1(grad, zeta):
     84                 self.__print_MSG(x, JVal, maxIter)
     85                 return x, JVal, True
     86                 
     87         print("Adam not converged after {} steps!".format(maxIter))
     88         return x, JVal, False
     89         
     90         
     91     def get_path(self):
     92         return self.__xPath, self.__JPath
     93             
     94             
     95     def __converged1(self, grad, epsilon):
     96         if numpy.linalg.norm(grad, ord=numpy.inf) < epsilon:
     97             return True
     98         return False
     99         
    100         
    101     def __converged2(self, xDelta, JDelta, epsilon):
    102         val1 = numpy.linalg.norm(xDelta, ord=numpy.inf)
    103         val2 = numpy.abs(JDelta)
    104         if val1 < epsilon or val2 < epsilon:
    105             return True
    106         return False
    107         
    108         
    109     def __print_MSG(self, x, JVal, iterCnt):
    110         print("Iteration steps: {}".format(iterCnt))
    111         print("Solution:\n{}".format(x.flatten()))
    112         print("JVal: {}".format(JVal))
    113         
    114         
    115     def __calc_JVal(self, x):
    116         return self.__func(x)
    117         
    118         
    119     def __calc_grad(self, x):
    120         return self.__grad(x)
    121         
    122         
    123     def __init_x(self):
    124         return self.__seed
    125         
    126         
    127     def __init_path(self):
    128         self.__xPath.clear()
    129         self.__JPath.clear()
    130         
    131         
    132     def __add_path(self, x, JVal):
    133         self.__xPath.append(x)
    134         self.__JPath.append(JVal)
    135         
    136                 
    137 class AdamPlot(object):
    138     
    139     @staticmethod
    140     def plot_fig(adamObj):
    141         x, JVal, tab = adamObj.get_solu(0.1)
    142         xPath, JPath = adamObj.get_path()
    143         
    144         fig = plt.figure(figsize=(10, 4))
    145         ax1 = plt.subplot(1, 2, 1)
    146         ax2 = plt.subplot(1, 2, 2)
    147         
    148         ax1.plot(numpy.arange(len(JPath)), JPath, "k.", markersize=1)
    149         ax1.plot(0, JPath[0], "go", label="starting point")
    150         ax1.plot(len(JPath)-1, JPath[-1], "r*", label="solution")
    151         
    152         ax1.legend()
    153         ax1.set(xlabel="$iterCnt$", ylabel="$JVal$")
    154         
    155         x1 = numpy.linspace(-100, 100, 300)
    156         x2 = numpy.linspace(-100, 100, 300)
    157         x1, x2 = numpy.meshgrid(x1, x2)
    158         f = numpy.zeros(x1.shape)
    159         for i in range(x1.shape[0]):
    160             for j in range(x1.shape[1]):
    161                 f[i, j] = func(numpy.array([[x1[i, j]], [x2[i, j]]]))
    162         ax2.contour(x1, x2, f, levels=36)
    163         x1Path = list(item[0] for item in xPath)
    164         x2Path = list(item[1] for item in xPath)
    165         ax2.plot(x1Path, x2Path, "k--", lw=2)
    166         ax2.plot(x1Path[0], x2Path[0], "go", label="starting point")
    167         ax2.plot(x1Path[-1], x2Path[-1], "r*", label="solution")
    168         ax2.set(xlabel="$x_1$", ylabel="$x_2$")
    169         ax2.legend()
    170                 
    171         fig.tight_layout()
    172         # plt.show()
    173         fig.savefig("plot_fig.png")
    174 
    175         
    176         
    177 if __name__ == "__main__":
    178     adamObj = Adam(func, grad, seed())
    179     
    180     AdamPlot.plot_fig(adamObj)
    View Code
  • 结果展示
  • 使用建议
    ①. 局部二阶矩求和一定程度上反应了局部的曲率信息, 用以近似并替代Hessian矩阵是合理的;
    ②. 文献中初始化参数推荐$\alpha=0.001, \beta_1=0.9, \beta_2=0.999, \epsilon=10^{-8}$, 实际根据需要优先调整步长参数$\alpha$.
  • 参考文档
    Kingma D P, Ba J. Adam: A method for stochastic optimization[J]. arXiv preprint arXiv:1412.6980, 2014.
posted @ 2021-07-26 23:42  LOGAN_XIONG  阅读(1803)  评论(0编辑  收藏  举报