| from sklearn import metrics |
| from sklearn.metrics import mean_squared_error |
| from sklearn.model_selection import cross_val_score |
| |
| from sklearn.model_selection import GridSearchCV |
| |
| import matplotlib.pyplot as plt |
| import matplotlib as mpl |
| mpl.rcParams['font.sans-serif'] = ['SimHei'] |
| mpl.rcParams['axes.unicode_minus'] = False |
| from sklearn.datasets import make_regression |
随机森林
| from sklearn.ensemble import RandomForestRegressor |
| X, y = make_regression(n_features=4, n_informative=2, |
| random_state=0, shuffle=False) |


| X_train=X[:70] |
| y_train=y[:70] |
| |
| X_test=X[:30] |
| y_test=y[:30] |
| regr = RandomForestRegressor(max_depth=2, random_state=0) |
| regr.fit(X_train, y_train) |
| |
| |
| y_pred=regr.predict(X_test) |
| print(y_pred) |
[ 41.71152007 -15.51877479 18.77435453 2.4613485 -5.25163664
11.98242971 -28.99147231 67.82781115 -46.47813223 58.94403962
-44.43019803 -25.35127762 -27.46837011 -31.48276853 17.81715876
-25.42572978 -16.172543 -20.43062853 -20.84673413 -30.25425251
17.90104445 67.70073552 28.81417535 33.29761523 40.28058259
-22.61219493 34.50175346 68.835082 38.18859153 -6.48249831]
| |
| |
| t = np.arange(len(X_test)) |
| plt.plot(t, y_test, 'r', linewidth=2, label='真实值') |

| |
| plt.plot(t, y_pred, 'g', linewidth=2, label='预测值') |

| |
| regr.score(X_test,y_test) |
0.8338446596824768
| |
| |
| metrics.mean_squared_error(y_test, y_pred) |
334.42748631188385
RandomForestRegressor(max_depth=2, random_state=0)
| regr.feature_importances_ |
array([0.15597865, 0.84082089, 0. , 0.00320046])
调优——k折交叉验证,scikit-learn的网格搜索GridSearchCV
| |
| param_grid = {"n_estimators":[5,50,100],"max_depth":[8,9,10]} |
| |
| grid_search = GridSearchCV(RandomForestRegressor(),param_grid,cv = 3) |
| grid_search.fit(X_train, y_train) |
GridSearchCV(cv=3, estimator=RandomForestRegressor(),
param_grid={'max_depth': [8, 9, 10], 'n_estimators': [5, 50, 100]})
| y_pred=grid_search.predict(X_test) |
| print(y_pred) |
[ 49.50191561 -0.7122897 15.26286215 17.50407347 15.87708862
-14.54908528 -13.32531612 80.64244515 -75.54860534 63.84753325
-68.76733049 -27.15074728 -34.90857798 -45.24935823 16.53953061
-25.26432862 -10.65729336 -18.79136562 -19.30815651 -38.14527267
6.93420609 88.31726657 16.87408796 34.57068077 53.79849864
-9.89424185 39.75832876 87.18227999 45.21303975 13.54728708]
| plt.figure(figsize=(15, 10)) |
| |
| t = np.arange(len(X_test)) |
| |
| plt.plot(t, y_test, 'r', linewidth=2, label='真实值') |
| |
| plt.plot(t, y_pred, 'g', linewidth=2, label='预测值') |
| |
| plt.legend() |
| plt.show() |

| |
| print("r2:", grid_search.score(X_test, y_test)) |
r2: 0.9866915026043963
| |
| print("MSE:", metrics.mean_squared_error(y_test, y_pred)) |
MSE: 26.786543978030636
GridSearchCV(cv=3, estimator=RandomForestRegressor(),
param_grid={'max_depth': [8, 9, 10], 'n_estimators': [5, 50, 100]})
| print(grid_search.best_params_) |
{'max_depth': 9, 'n_estimators': 50}
调优——k折交叉验证+逐个参数
| superpa = [] |
| for i in range(10,200,10): |
| regr = RandomForestRegressor(n_estimators=i |
| ,random_state=42 |
| ) |
| regr_s = cross_val_score(regr |
| ,X_train |
| ,y_train |
| ,cv=10 |
| |
| ).mean() |
| superpa.append(regr_s) |
| print(max(superpa),superpa.index(max(superpa)),(superpa.index(max(superpa)))*10+10) |
| |
| plt.figure(figsize=[20,5]) |
| plt.plot(range(10,200,10),superpa) |
| plt.show() |

| |
| superpa = [] |
| for i in range(10,30,2): |
| regr = RandomForestRegressor(n_estimators=170 |
| ,max_depth=i |
| ,random_state=42 |
| ) |
| regr_s = cross_val_score(regr |
| ,X_train |
| ,y_train |
| ,cv=10 |
| |
| ).mean() |
| superpa.append(regr_s) |
| print(max(superpa),superpa.index(max(superpa)),(superpa.index(max(superpa)))*2+10) |
| |
| plt.figure(figsize=[20,5]) |
| plt.plot(range(10,30,2),superpa) |
| |
| plt.show() |

| |
| superpa = [] |
| for i in range(2,10,2): |
| regr = RandomForestRegressor(n_estimators=170 |
| ,max_depth=12 |
| ,min_samples_split=i |
| ,random_state=42 |
| ,n_jobs=-1) |
| regr_s = cross_val_score(regr |
| ,X_train |
| ,y_train |
| ,cv=10 |
| |
| ).mean() |
| superpa.append(regr_s) |
| print(max(superpa),superpa.index(max(superpa)),(superpa.index(max(superpa)))*2+2) |
| |
| plt.figure(figsize=[20,5]) |
| plt.plot(range(2,10,2),superpa) |
| plt.show() |

| |
| superpa = [] |
| for i in range(1,15,1): |
| regr = RandomForestRegressor(n_estimators=170 |
| ,max_depth=12 |
| ,min_samples_split=2 |
| ,min_samples_leaf=i |
| ,random_state=42 |
| ) |
| regr_s = cross_val_score(regr |
| ,X_train |
| ,y_train |
| ,cv=10 |
| |
| ).mean() |
| superpa.append(regr_s) |
| print(max(superpa),superpa.index(max(superpa)),(superpa.index(max(superpa)))*1+1) |
| |
| plt.figure(figsize=[20,5]) |
| plt.plot(range(1,15,1),superpa) |
| plt.show() |

| |
| |
| |
| param_grid = {'max_features':np.arange(3, 11, 1)} |
| |
| regr = RandomForestRegressor(n_estimators=170 |
| ,max_depth=12 |
| ,min_samples_split=2 |
| ,min_samples_leaf=1 |
| ,random_state=42 |
| ) |
| GS = GridSearchCV(regr,param_grid,cv=10) |
| GS.fit(X_train, y_train) |
| |
| print(GS.best_params_) |
| |
{'max_features': 4}
最终模型:
| regr = RandomForestRegressor(n_estimators=170 |
| ,max_depth=12 |
| ,min_samples_split=2 |
| ,min_samples_leaf=1 |
| ,random_state=42 |
| ) |
| |
| regr.score(X_test,y_test) |
0.9879834206877871
| |
| metrics.mean_squared_error(y_test, y_pred) |
24.275331856914285
RandomForestRegressor(max_depth=12, n_estimators=170, random_state=42)
| regr.feature_importances_ |
array([0.21877631, 0.7564047 , 0.01129733, 0.01352166])
| plt.figure(figsize=(15, 10)) |
| |
| t = np.arange(len(X_test)) |
| |
| plt.plot(t, y_test, 'r', linewidth=2, label='真实值') |
| |
| plt.plot(t, y_pred, 'g', linewidth=2, label='预测值') |
| |
| plt.legend() |
| plt.show() |

【推荐】国内首个AI IDE,深度理解中文开发场景,立即下载体验Trae
【推荐】编程新体验,更懂你的AI,立即体验豆包MarsCode编程助手
【推荐】抖音旗下AI助手豆包,你的智能百科全书,全免费不限次数
【推荐】轻量又高性能的 SSH 工具 IShell:AI 加持,快人一步
· 震惊!C++程序真的从main开始吗?99%的程序员都答错了
· 【硬核科普】Trae如何「偷看」你的代码?零基础破解AI编程运行原理
· 单元测试从入门到精通
· 上周热点回顾(3.3-3.9)
· Vue3状态管理终极指南:Pinia保姆级教程