# Implementation of Adam
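
# The script below minimizes the quadratic objective
#     f(x1, x2) = 5*x1**2 + 2*x2**2 + 3*x1 - 10*x2 + 4,
# whose unique minimizer is (x1, x2) = (-0.3, 2.5) with f = -8.95, using the
# standard Adam recursion (k = 1, 2, ..., with g_k the gradient at the current iterate):
#     m_k = beta1 * m_{k-1} + (1 - beta1) * g_k
#     v_k = beta2 * v_{k-1} + (1 - beta2) * g_k**2
#     x_k = x_{k-1} - alpha * (m_k / (1 - beta1**k)) / (sqrt(v_k / (1 - beta2**k)) + epsilon)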
import numpy
from matplotlib import pyplot as plt


# zeroth-order information of the objective function (function value)
def func(X):
    funcVal = 5 * X[0, 0] ** 2 + 2 * X[1, 0] ** 2 + 3 * X[0, 0] - 10 * X[1, 0] + 4
    return funcVal


# first-order information of the objective function (gradient)
def grad(X):
    grad_x1 = 10 * X[0, 0] + 3
    grad_x2 = 4 * X[1, 0] - 10
    gradVec = numpy.array([[grad_x1], [grad_x2]])
    return gradVec


# random starting point for the iteration
def seed(n=2):
    seedVec = numpy.random.uniform(-100, 100, (n, 1))
    return seedVec
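

# Optional sanity check (illustrative sketch, not part of the solver below):
# compare the analytic gradient with a central finite-difference approximation
# of func at a point x0. The helper name check_grad is introduced here for
# demonstration only; it returns the max-norm of the difference.
def check_grad(x0, h=1.e-6):
    numGrad = numpy.zeros(x0.shape)
    for i in range(x0.shape[0]):
        e = numpy.zeros(x0.shape)
        e[i, 0] = h
        numGrad[i, 0] = (func(x0 + e) - func(x0 - e)) / (2 * h)
    return numpy.linalg.norm(numGrad - grad(x0), ord=numpy.inf)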


class Adam(object):

    def __init__(self, _func, _grad, _seed):
        '''
        _func: objective function to be minimized
        _grad: gradient of the objective function
        _seed: starting point of the iteration
        '''
        self.__func = _func
        self.__grad = _grad
        self.__seed = _seed

        self.__xPath = list()
        self.__JPath = list()


    def get_solu(self, alpha=0.001, beta1=0.9, beta2=0.999, epsilon=1.e-8, zeta=1.e-6, maxIter=3000000):
        '''
        Compute the numerical solution.
        alpha: step-size parameter
        beta1: exponential decay rate for the first-moment estimate
        beta2: exponential decay rate for the second-moment estimate
        epsilon: small positive constant to avoid division by zero
        zeta: convergence tolerance
        maxIter: maximum number of iterations
        '''
        self.__init_path()

        x = self.__init_x()
        JVal = self.__calc_JVal(x)
        self.__add_path(x, JVal)
        grad = self.__calc_grad(x)
        m, v = numpy.zeros(x.shape), numpy.zeros(x.shape)
        for k in range(1, maxIter + 1):
            # print("k: {:3d}, JVal: {}".format(k, JVal))
            if self.__converged1(grad, zeta):
                self.__print_MSG(x, JVal, k)
                return x, JVal, True

            # update the biased first- and second-moment estimates
            m = beta1 * m + (1 - beta1) * grad
            v = beta2 * v + (1 - beta2) * grad * grad
            # bias-corrected moment estimates
            m_ = m / (1 - beta1 ** k)
            v_ = v / (1 - beta2 ** k)

            # element-wise step size and descent direction
            alpha_ = alpha / (numpy.sqrt(v_) + epsilon)
            d = -m_
            xNew = x + alpha_ * d
            JNew = self.__calc_JVal(xNew)
            self.__add_path(xNew, JNew)
            if self.__converged2(xNew - x, JNew - JVal, zeta ** 2):
                self.__print_MSG(xNew, JNew, k + 1)
                return xNew, JNew, True

            gNew = self.__calc_grad(xNew)
            x, JVal, grad = xNew, JNew, gNew
        else:
            # reached only if the loop ran maxIter iterations without returning
            if self.__converged1(grad, zeta):
                self.__print_MSG(x, JVal, maxIter)
                return x, JVal, True

        print("Adam not converged after {} steps!".format(maxIter))
        return x, JVal, False


    def get_path(self):
        return self.__xPath, self.__JPath


    def __converged1(self, grad, epsilon):
        if numpy.linalg.norm(grad, ord=numpy.inf) < epsilon:
            return True
        return False


    def __converged2(self, xDelta, JDelta, epsilon):
        val1 = numpy.linalg.norm(xDelta, ord=numpy.inf)
        val2 = numpy.abs(JDelta)
        if val1 < epsilon or val2 < epsilon:
            return True
        return False


    def __print_MSG(self, x, JVal, iterCnt):
        print("Iteration steps: {}".format(iterCnt))
        print("Solution:\n{}".format(x.flatten()))
        print("JVal: {}".format(JVal))


    def __calc_JVal(self, x):
        return self.__func(x)


    def __calc_grad(self, x):
        return self.__grad(x)


    def __init_x(self):
        return self.__seed


    def __init_path(self):
        self.__xPath.clear()
        self.__JPath.clear()


    def __add_path(self, x, JVal):
        self.__xPath.append(x)
        self.__JPath.append(JVal)


class AdamPlot(object):

    @staticmethod
    def plot_fig(adamObj):
        x, JVal, tab = adamObj.get_solu(0.1)
        xPath, JPath = adamObj.get_path()

        fig = plt.figure(figsize=(10, 4))
        ax1 = plt.subplot(1, 2, 1)
        ax2 = plt.subplot(1, 2, 2)

        # left panel: objective value versus iteration count
        ax1.plot(numpy.arange(len(JPath)), JPath, "k.", markersize=1)
        ax1.plot(0, JPath[0], "go", label="starting point")
        ax1.plot(len(JPath) - 1, JPath[-1], "r*", label="solution")

        ax1.legend()
        ax1.set(xlabel="$iterCnt$", ylabel="$JVal$")

        # right panel: contour plot of the objective with the iterate path
        x1 = numpy.linspace(-100, 100, 300)
        x2 = numpy.linspace(-100, 100, 300)
        x1, x2 = numpy.meshgrid(x1, x2)
        f = numpy.zeros(x1.shape)
        for i in range(x1.shape[0]):
            for j in range(x1.shape[1]):
                f[i, j] = func(numpy.array([[x1[i, j]], [x2[i, j]]]))
        ax2.contour(x1, x2, f, levels=36)
        x1Path = list(item[0] for item in xPath)
        x2Path = list(item[1] for item in xPath)
        ax2.plot(x1Path, x2Path, "k--", lw=2)
        ax2.plot(x1Path[0], x2Path[0], "go", label="starting point")
        ax2.plot(x1Path[-1], x2Path[-1], "r*", label="solution")
        ax2.set(xlabel="$x_1$", ylabel="$x_2$")
        ax2.legend()

        fig.tight_layout()
        # plt.show()
        fig.savefig("plot_fig.png")


if __name__ == "__main__":
    adamObj = Adam(func, grad, seed())

    AdamPlot.plot_fig(adamObj)
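
    # For illustration, the solver can also be called directly; note that plot_fig
    # above passes alpha=0.1 rather than the default 0.001. When the run converges,
    # the reported solution should lie close to the analytic minimizer (-0.3, 2.5),
    # where the objective value is -8.95.
    # x, JVal, converged = Adam(func, grad, seed()).get_solu(alpha=0.1)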