def f(x):
return x ** 3 - 2 * x + 1 # 返回函数的值
def f1(s0, s1, s2):
return (((s1 ** 2 - s2 ** 3) * f(s0) + (s2 ** 2 - s0 ** 2) * f(s1) + (s0 ** 2 - s1 ** 2) * f(s2)) / ((s1 - s0) * f(
s0) + (s2 - s0) * f(s1) + (s0 - s1) * f(s2))) / 2 # 计算s^*
def solve(s0, h, epsilon):
s1 = s0 + h #计算s1
s2 = s1 + 2 * h # 计算s2
s = f1(s0, s1, s2) # 计算s*
phi0 = f(s0) # 计算\phi_0
phi1 = f(s1)
phi2 = f(s2)
phi = f(s)
while True:
if abs(s2 - s0) <= epsilon:
return s0 # 达到精度时返回s_0,返回\frac{s0+s2}{2}也可以
break # 并结束循环
else:
s = f1(s0, s1, s2) # 否则则计算s^*
phi = f(s) # 计算\phi_x^*
if phi1 <= phi:
if s1 < s: # 如果s^*在s1右边
s2 = s # 将s^*赋值给s2
phi2 = phi # 避免再次计算\phi,将\phi_*赋值给\phi_2
continue
else: # 如果s^*在s1右边
s0 = s # 将s_*赋值给s_0
phi0 = phi # 将\phi赋值给\phi_0,减少运算量
continue # 结束当次循环
else: # 如果\phi_1>\phi
if s1 > s: # s^*在s_1左边
s2 = s1
s1 = s
phi2 = phi1
phi1 = phi
continue
else:
s0 = s1
s1 = s
phi0 = phi1
phi1 = phi
continue
if __name__ == '__main__':
print(solve(1, 0.1, 0.05)) # 传入参数s_0, h, \epsilon