scikit-opt: SA (Simulated Annealing)
I. Finding the minimum of a function with SA
Step 1: Define your problem
```python
# Objective to minimise: f(x) = x0^2 + (x1 - 0.05)^2 + x2^2
demo_func = lambda x: x[0] ** 2 + (x[1] - 0.05) ** 2 + x[2] ** 2
```
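Because the objective is a sum of squares, its minimum can be read off by hand. This gives a target to compare the SA output against: `best_x` should come out close to `[0, 0.05, 0]` and `best_y` close to 0.

```latex
f(x) = x_0^2 + (x_1 - 0.05)^2 + x_2^2 \;\ge\; 0,
\qquad
f(x) = 0 \iff x = (0,\ 0.05,\ 0)
```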
Step 2: Run SA
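```python
from sko.SA import SA

# x0: starting point; T_max / T_min: initial and final temperature
# L: number of iterations at each temperature (chain length)
# max_stay_counter: stop early if the best value stays unchanged this many times in a row
sa = SA(func=demo_func, x0=[1, 1, 1], T_max=1, T_min=1e-9, L=300, max_stay_counter=150)
best_x, best_y = sa.run()
print('best_x:', best_x, 'best_y', best_y)
```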
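To make the parameters above concrete, here is a minimal from-scratch sketch of the classic SA loop: perturb the current point, always accept improvements, accept a worse point with the Metropolis probability exp(-Δ/T), and cool the temperature geometrically. The function `simulated_annealing` and its `alpha`, `step`, and `seed` parameters are hypothetical and chosen only for illustration. `T_max`, `T_min`, and `L` play the same roles as in the `SA(...)` call, while early stopping via `max_stay_counter` is left out. This is a sketch of the technique, not scikit-opt's internal implementation.

```python
import math
import random


def simulated_annealing(func, x0, T_max=1.0, T_min=1e-9, L=300,
                        alpha=0.9, step=0.5, seed=None):
    """Minimise func from the start point x0 with a basic SA loop.

    Hypothetical helper for illustration only (not scikit-opt's code):
    Gaussian moves whose size shrinks with sqrt(T), Metropolis acceptance,
    and geometric cooling T <- alpha * T.
    """
    rng = random.Random(seed)
    x = list(x0)
    y = func(x)
    best_x, best_y = list(x), y
    T = T_max
    while T > T_min:
        scale = step * math.sqrt(T / T_max)  # smaller moves as the system cools
        for _ in range(L):  # L proposals at each temperature level
            x_new = [xi + rng.gauss(0.0, scale) for xi in x]
            y_new = func(x_new)
            dy = y_new - y
            # Metropolis rule: always accept an improvement; accept a worse
            # point with probability exp(-dy / T)
            if dy <= 0 or rng.random() < math.exp(-dy / T):
                x, y = x_new, y_new
                if y < best_y:
                    best_x, best_y = list(x), y
        T *= alpha  # geometric cooling schedule
    return best_x, best_y


demo_func = lambda x: x[0] ** 2 + (x[1] - 0.05) ** 2 + x[2] ** 2
best_x, best_y = simulated_annealing(demo_func, [1, 1, 1], seed=0)
print('best_x:', best_x, 'best_y', best_y)
```

Shrinking the step size together with the temperature is one simple design choice: large jumps explore early on, small ones refine the answer once almost no uphill moves are accepted.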
Step 3: Plot the results
```python
import matplotlib.pyplot as plt
import pandas as pd

# cummin turns the recorded best values into a running best-so-far curve
plt.plot(pd.DataFrame(sa.best_y_history).cummin(axis=0))
plt.show()
```
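The `cummin` call replaces each entry with the smallest value seen so far, so the plotted curve can only go down. If a history already stores best-so-far values, `cummin` leaves it unchanged. A tiny illustration on made-up numbers (the `history` list is hypothetical), with the equivalent NumPy call for comparison:

```python
import numpy as np
import pandas as pd

# Hypothetical per-iteration objective values, for illustration only
history = [5.0, 3.0, 4.0, 1.0, 2.0]

# Running minimum down the single column
running_best = pd.DataFrame(history).cummin(axis=0)
print(running_best[0].tolist())  # [5.0, 3.0, 3.0, 1.0, 1.0]

# NumPy equivalent: accumulate the elementwise minimum
print(np.minimum.accumulate(history).tolist())  # [5.0, 3.0, 3.0, 1.0, 1.0]
```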
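In addition, scikit-opt provides three variants of simulated annealing: Fast, Boltzmann, and Cauchy. See the SA section of the documentation for more.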
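The three variants differ mainly in how quickly the temperature falls (and also in how candidate points are drawn, which this sketch does not cover). Below are the textbook cooling schedules commonly associated with these names: logarithmic for Boltzmann, 1/k for Cauchy, and an exponential form assumed here for "fast". The helpers `boltzmann_T`, `cauchy_T`, and `fast_T` and the parameters `c` and `q` are hypothetical; scikit-opt's own formulas and defaults may differ.

```python
import math


# Hypothetical helpers: textbook cooling schedules. k counts temperature
# updates (k >= 1); T0 is the initial temperature. scikit-opt's own
# formulas and parameters may differ.

def boltzmann_T(T0, k):
    """Boltzmann annealing: logarithmic cooling, T_k = T0 / ln(1 + k)."""
    return T0 / math.log(1 + k)


def cauchy_T(T0, k):
    """Cauchy annealing: T_k = T0 / (1 + k)."""
    return T0 / (1 + k)


def fast_T(T0, k, c=1.0, q=1.0):
    """Exponential schedule T_k = T0 * exp(-c * k**q) (assumed form for 'fast')."""
    return T0 * math.exp(-c * k ** q)


for k in (1, 10, 100):
    print(f"k={k:>3}  boltzmann={boltzmann_T(100, k):9.3f}  "
          f"cauchy={cauchy_T(100, k):9.3f}  fast={fast_T(100, k):.3e}")
```

Printing a few values shows the ordering: the logarithmic schedule cools slowest, the 1/k schedule faster, and the exponential one fastest.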
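II. Solving the TSP with SA
Step 1: Define the problem. The TSP (traveling salesman problem) asks for the shortest closed route that visits every point exactly once and returns to the starting point.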
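```python
import sys

import numpy as np
from scipy import spatial

# Read point coordinates (one point per row, comma-separated); the path can be passed on the command line
file_name = sys.argv[1] if len(sys.argv) > 1 else 'data/nctu.csv'
points_coordinate = np.loadtxt(file_name, delimiter=',')
num_points = points_coordinate.shape[0]
distance_matrix = spatial.distance.cdist(points_coordinate, points_coordinate, metric='euclidean')
# Rough conversion from degrees to metres: 1 degree of latitude ~= 111000 m
# (a degree of longitude is shorter by a factor of cos(latitude), so this is approximate)
distance_matrix = distance_matrix * 111000


def cal_total_distance(routine):
    '''The objective function. input routine, return total distance.
    cal_total_distance(np.arange(num_points))
    '''
    num_points, = routine.shape
    # (i + 1) % num_points wraps around, so the edge back to the start is included
    return sum([distance_matrix[routine[i % num_points], routine[(i + 1) % num_points]] for i in range(num_points)])
```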
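This reads the point coordinates from nctu.csv, builds the pairwise distance matrix, and defines the objective function: the total length of the closed route.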
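A quick way to sanity-check an objective like `cal_total_distance` is to run the same logic on a toy instance whose answer is known. The sketch below uses four hypothetical points at the corners of a unit square and computes the distance matrix with NumPy broadcasting (equivalent to `cdist(..., metric='euclidean')`). The helper `route_length` is a hypothetical stand-in for `cal_total_distance`.

```python
import numpy as np

# Hypothetical toy instance: the four corners of a unit square
points = np.array([[0, 0], [1, 0], [1, 1], [0, 1]], dtype=float)
# Pairwise Euclidean distances via broadcasting (same result as cdist(..., 'euclidean'))
dist = np.linalg.norm(points[:, None, :] - points[None, :, :], axis=-1)


def route_length(routine):
    """Total length of the closed tour, including the edge back to the start."""
    n = len(routine)
    # (i + 1) % n wraps the last index around to 0, closing the loop
    return sum(dist[routine[i], routine[(i + 1) % n]] for i in range(n))


print(route_length(np.array([0, 1, 2, 3])))  # around the perimeter: 4.0
print(route_length(np.array([0, 2, 1, 3])))  # two diagonals and two sides: 2 + 2*sqrt(2)
```

The crossing route is longer than the perimeter, which is exactly the kind of difference SA exploits when it rearranges the visiting order.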
Step 2: Run SA on the TSP
```python
from sko.SA import SA_TSP

# x0: initial route (visit the points in index order); L: iterations at each temperature
sa_tsp = SA_TSP(func=cal_total_distance, x0=range(num_points), T_max=100, T_min=1, L=10 * num_points)
best_points, best_distance = sa_tsp.run()
print(best_points, best_distance, cal_total_distance(best_points))
```
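For the TSP the search space is permutations rather than real vectors, so "perturbing the current point" becomes rearranging the route. A minimal sketch with one common neighbour move, reversing a random segment of the route (a 2-opt move), follows. `sa_tsp_sketch`, `tour_length`, `alpha`, and `seed` are hypothetical names; scikit-opt's actual move set and cooling schedule for `SA_TSP` may differ. Note that the temperature must be on the same scale as typical changes in route length: with unit-circle coordinates the example uses `T_max=1.0, T_min=1e-3` rather than the 100 and 1 used above.

```python
import math
import random

import numpy as np


def tour_length(route, dist):
    """Closed-tour length: sum of edges, including the edge back to the start."""
    n = len(route)
    return sum(dist[route[i], route[(i + 1) % n]] for i in range(n))


def sa_tsp_sketch(dist, T_max=100.0, T_min=1.0, L=None, alpha=0.95, seed=None):
    """Hypothetical SA for the TSP using 2-opt style segment reversal.

    Illustration of the idea only, not scikit-opt's SA_TSP internals.
    """
    rng = random.Random(seed)
    n = len(dist)
    L = L if L is not None else 10 * n  # mirrors L=10 * num_points above
    route = list(range(n))
    cost = tour_length(route, dist)
    best_route, best_cost = route[:], cost
    T = T_max
    while T > T_min:
        for _ in range(L):
            i, j = sorted(rng.sample(range(n), 2))
            # Neighbour: reverse the segment route[i..j] (a 2-opt move)
            new_route = route[:i] + route[i:j + 1][::-1] + route[j + 1:]
            new_cost = tour_length(new_route, dist)
            delta = new_cost - cost
            # Metropolis acceptance, as in the continuous case
            if delta <= 0 or rng.random() < math.exp(-delta / T):
                route, cost = new_route, new_cost
                if cost < best_cost:
                    best_route, best_cost = route[:], cost
        T *= alpha  # geometric cooling
    return best_route, best_cost


# Hypothetical instance: 8 points on the unit circle, listed in scrambled order
angles = np.linspace(0, 2 * np.pi, 8, endpoint=False)
order = [0, 3, 6, 1, 4, 7, 2, 5]
pts = np.column_stack([np.cos(angles[order]), np.sin(angles[order])])
dist = np.linalg.norm(pts[:, None, :] - pts[None, :, :], axis=-1)
best_route, best_cost = sa_tsp_sketch(dist, T_max=1.0, T_min=1e-3, seed=1)
print(best_route, best_cost)
```

For points on a circle the optimal tour is the regular octagon, so the result can be checked against its perimeter, 16·sin(π/8).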
Step 3: Plot the results
```python
# %% Plot the best routine
import matplotlib.pyplot as plt
import numpy as np
from matplotlib.ticker import FormatStrFormatter

fig, ax = plt.subplots(1, 2)
# Append the first point again so the plotted route closes into a loop
best_points_ = np.concatenate([best_points, [best_points[0]]])
best_points_coordinate = points_coordinate[best_points_, :]
# Left panel: convergence curve; right panel: best route found
ax[0].plot(sa_tsp.best_y_history)
ax[0].set_xlabel("Iteration")
ax[0].set_ylabel("Distance")
ax[1].plot(best_points_coordinate[:, 0], best_points_coordinate[:, 1],
           marker='o', markerfacecolor='b', color='c', linestyle='-')
ax[1].xaxis.set_major_formatter(FormatStrFormatter('%.3f'))
ax[1].yaxis.set_major_formatter(FormatStrFormatter('%.3f'))
ax[1].set_xlabel("Longitude")
ax[1].set_ylabel("Latitude")
plt.show()
```
More: Plot an animation
```python
# %% Plot the animation
import matplotlib.pyplot as plt
import numpy as np
from matplotlib.animation import FuncAnimation
from matplotlib.ticker import FormatStrFormatter

best_x_history = sa_tsp.best_x_history  # best routes recorded during the run

fig2, ax2 = plt.subplots(1, 1)
ax2.set_title('title', loc='center')
line, = ax2.plot(points_coordinate[:, 0], points_coordinate[:, 1],
                 marker='o', markerfacecolor='b', color='c', linestyle='-')
ax2.xaxis.set_major_formatter(FormatStrFormatter('%.3f'))
ax2.yaxis.set_major_formatter(FormatStrFormatter('%.3f'))
ax2.set_xlabel("Longitude")
ax2.set_ylabel("Latitude")


def update_scatter(frame):
    ax2.set_title('iter = ' + str(frame))
    points = best_x_history[frame]
    points = np.concatenate([points, [points[0]]])  # close the loop
    point_coordinate = points_coordinate[points, :]
    line.set_data(point_coordinate[:, 0], point_coordinate[:, 1])
    return line,


# blit=False so the title, which update_scatter does not return, is redrawn on every frame
ani = FuncAnimation(fig2, update_scatter, blit=False, interval=25, frames=len(best_x_history))
ani.save('sa_tsp.gif', writer='pillow')  # requires the Pillow package
plt.show()
```
Reference: scikit-opt official documentation, SA section