import numpy as np
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d.axes3d import Axes3D
# Objective function: f(x, y) = 2x^2 + 6y^2 + 6xy + x + 4y + 8
def targetFunc(x, y):
    """Return f(x, y) = 2x^2 + 6y^2 + 6xy + x + 4y + 8.

    Only arithmetic operators are used, so ``x`` and ``y`` may be plain
    numbers or NumPy arrays (the script below evaluates it on a meshgrid).
    """
    # Terms listed in the same order as the formula; sum() adds them left to right.
    terms = (2 * x ** 2, 6 * y ** 2, 6 * x * y, x, 4 * y, 8)
    return sum(terms)
# Partial derivatives of targetFunc:
#   df/dx (x, y) = 4x + 6y + 1
#   df/dy (x, y) = 12y + 6x + 4
def derivativeFunc(x, y):
    """Return the gradient ``(df/dx, df/dy)`` of targetFunc at ``(x, y)``.

    The result is a 2-tuple; each component is computed with plain
    arithmetic, so scalars or NumPy arrays are both accepted.
    """
    return (4 * x + 6 * y + 1, 12 * y + 6 * x + 4)
# Trace of every point visited by linerFunc, stored as (x, y, f(x, y)) tuples.
# It lives at module level and is appended to (never cleared) on each call,
# so reset it yourself between runs if you need a fresh trace.
pointList = []


def linerFunc(initPoint: tuple, targetFunc, derivativeFunc, step=0.01,
              limitValue=0.00000001, timeout=1000000, ax: "Axes3D" = None):
    """Search for a stationary point of a 2-D function by fixed-step gradient descent.

    Starting at ``initPoint``, the point is repeatedly moved by
    ``-step * gradient``. The search stops when one of these happens:

    - every gradient component is within ``limitValue`` of zero;
    - the gradient becomes non-finite (divergence);
    - ``timeout`` steps have been taken.

    Every visited point, including the start and the returned one, is
    appended to the module-level ``pointList`` as ``(x, y, f(x, y))``.

    Args:
        initPoint: starting coordinates ``(x, y)``.
        targetFunc: callable ``f(x, y)`` returning the function value.
        derivativeFunc: callable returning the gradient ``(df/dx, df/dy)``.
        step: fixed step size (learning rate).
        limitValue: convergence tolerance on each gradient component.
        timeout: maximum number of descent steps.
        ax: unused; kept for backward compatibility (the caller plots
            ``pointList`` itself).

    Returns:
        ``(value, point)``: the function value at the final point and the
        final point as a float NumPy array.

    Warns:
        RuntimeWarning if the gradient tolerance was not reached (timeout or
        divergence). The last point is still returned.
    """
    import warnings

    point = np.asarray(initPoint, dtype=float)
    value = targetFunc(*point)
    grad = np.asarray(derivativeFunc(*point), dtype=float)
    pointList.append((*point, value))

    count = 0
    # Stop on the gradient itself, not on its change between steps: a small
    # change also happens wherever the gradient is constant (e.g. f = x + y),
    # which is not a stationary point. `not all(<=)` is used instead of
    # `any(>)` so that NaN keeps the convergence test failing.
    while count < timeout and not np.all(np.abs(grad) <= limitValue):
        if not np.all(np.isfinite(grad)):
            break  # diverged: step too large or function unbounded below
        point = point - step * grad
        value = targetFunc(*point)
        grad = np.asarray(derivativeFunc(*point), dtype=float)
        pointList.append((*point, value))
        count += 1

    # Report the number of descent steps taken ("final iteration count").
    print("最终运算次数为 : {0}".format(count))
    if not np.all(np.abs(grad) <= limitValue):
        warnings.warn(
            "linerFunc: gradient did not fall below limitValue after "
            "{0} steps; the returned point is not a stationary point".format(count),
            RuntimeWarning,
            stacklevel=2,
        )
    return value, point
if __name__ == "__main__":
    # Sample the objective on a 100x100 grid over [-2, 23] x [-2, 23].
    x, y = np.meshgrid(np.linspace(-2, 23, 100), np.linspace(-2, 23, 100))
    fxy = targetFunc(x, y)

    fig = plt.figure()
    # Axes3D(fig) no longer attaches itself to the figure (deprecated in
    # Matplotlib 3.4, removed in 3.6), which leaves an empty window; create
    # the 3-D axes through the figure instead.
    ax = fig.add_subplot(111, projection="3d")
    # Semi-transparent so the descent trace drawn below stays visible.
    ax.plot_surface(x, y, fxy, alpha=0.6)

    limitValue, limitPoint = linerFunc((20, 20), targetFunc, derivativeFunc, ax=ax)
    # pointList holds the (x, y, f(x, y)) trace recorded by linerFunc.
    ax.scatter(*(np.array(pointList).T), c='r', s=20, label="gradient descent path")
    print("该函数在({0},{1})处有驻点,值为{2}".format(limitPoint[0], limitPoint[1], limitValue))
    # The scatter is labelled, so the legend has an entry to show.
    ax.legend()
    plt.show()