基于pytorch对函数进行极值求解

 

 1 import numpy as np
 2 from mpl_toolkits.mplot3d import Axes3D
 3 import matplotlib.pyplot as plt
 4 from matplotlib.colors import LinearSegmentedColormap
 5 
 6 # 待求极值的函数
 7 def himmelblau(t):# t[0]-->X; t[1]-->Y.
 8     return (t[0] ** 2 + t[1] - 11) ** 2 + (t[0] + t[1] ** 2 - 7) ** 2
 9 
10 x = np.arange(-6, 6, 0.1)
11 y = np.arange(-6, 6, 0.1)
12 X, Y = np.meshgrid(x, y)
13 Z = himmelblau([X, Y])
14 fig = plt.figure()
15 ax = fig.add_subplot(projection='3d')# ax = fig.gca(projection='3d') # ---> was deprecated in Matplotlib 3.4
16 ax.plot_surface(X, Y, Z)
17 ax.view_init(60, -30)
18 ax.set_xlabel('x')
19 ax.set_ylabel('y')
20 fig.show()
21 plt.show()
22 
23 # function test
24 def jeshy(t):
25     return t*3+10
26 
27 import torch
28 x = torch.tensor([0., 0.], requires_grad=True)
29 optimizer = torch.optim.Adam([x, ])# optim.Adam([var1, var2], lr=0.0001)# 优化器设置 ,并传入模型参数和相应的学习率
30 for step in range(20001):
31     f = himmelblau(x)# 前向传播
32     if step > 0:
33         optimizer.zero_grad()# 反向传播与优化# 清空上一步的残余更新参数值
34         f.backward(retain_graph=True)# 反向传播与优化# 反向传播
35         optimizer.step()# 反向传播与优化# 将参数更新值施加到函数f的parameters上
36     # f = jeshy(f)
37     if step % 1000 == 0:# 每迭代一定步骤,打印结果值
38         print('step:{}, x = {}, value = {}'.format(step, x.tolist(), f))

 

 输出:

step:0, x = [0.0, 0.0], value = 170.0
step:1000, x = [1.270142912864685, 1.1183991432189941], value = 88.53223419189453
step:2000, x = [2.332378387451172, 1.9535712003707886], value = 13.766233444213867
step:3000, x = [2.8519949913024902, 2.114161968231201], value = 0.6711398363113403
step:4000, x = [2.981964111328125, 2.0271568298339844], value = 0.014927156269550323
step:5000, x = [2.9991261959075928, 2.0014777183532715], value = 3.9870232285466045e-05
step:6000, x = [2.999983549118042, 2.0000221729278564], value = 1.1074007488787174e-08
step:7000, x = [2.9999899864196777, 2.000013589859009], value = 4.150251697865315e-09
step:8000, x = [2.9999938011169434, 2.0000083446502686], value = 1.5572823031106964e-09
step:9000, x = [2.9999964237213135, 2.000005006790161], value = 5.256879376247525e-10
step:10000, x = [2.999997854232788, 2.000002861022949], value = 1.8189894035458565e-10
step:11000, x = [2.9999988079071045, 2.0000014305114746], value = 5.547917680814862e-11
step:12000, x = [2.9999992847442627, 2.0000009536743164], value = 1.6370904631912708e-11
step:13000, x = [2.999999523162842, 2.000000476837158], value = 5.6843418860808015e-12
step:14000, x = [2.999999761581421, 2.000000238418579], value = 1.8189894035458565e-12
step:15000, x = [3.0, 2.0], value = 0.0
step:16000, x = [3.0, 2.0], value = 0.0
step:17000, x = [3.0, 2.0], value = 0.0
step:18000, x = [3.0, 2.0], value = 0.0
step:19000, x = [3.0, 2.0], value = 0.0
step:20000, x = [3.0, 2.0], value = 0.0

posted @ 2021-10-26 10:54  土博姜山山  阅读(247)  评论(0编辑  收藏  举报