使用GridSearchCV寻找最佳参数组合——机器学习工具箱代码
# -*- coding: utf-8 -*-
"""Find the best GradientBoostingClassifier hyper-parameters with GridSearchCV.

A synthetic binary-classification data set is generated, a grid of
``n_estimators`` / ``max_depth`` values is evaluated with 5-fold
cross-validation, and the best parameter combination and its mean
cross-validated accuracy are printed.
"""
import numpy as np
import xgboost as xgb
from matplotlib import pyplot as plt
from sklearn import datasets
from sklearn import metrics
from sklearn.ensemble import GradientBoostingClassifier
from sklearn.feature_extraction import FeatureHasher
from sklearn.model_selection import GridSearchCV
from sklearn.model_selection import train_test_split
from sklearn.neighbors import KNeighborsClassifier


def report(test_Y, pred_Y, score_Y=None):
    """Print binary-classification metrics and show the ROC curve.

    Args:
        test_Y: Ground-truth labels.
        pred_Y: Predicted hard labels; used for accuracy, F1, recall,
            precision and the confusion matrix.
        score_Y: Optional continuous scores for the positive class (for
            example ``clf.predict_proba(X)[:, 1]``) used for the AUC value
            and the ROC curve. When omitted, ``pred_Y`` is used, which
            yields a ROC curve with a single operating point.

    Returns:
        None. The metrics are printed and the plot is shown with
        ``plt.show()``.
    """
    ranking = pred_Y if score_Y is None else score_Y

    print("accuracy_score:")
    print(metrics.accuracy_score(test_Y, pred_Y))
    print("f1_score:")
    print(metrics.f1_score(test_Y, pred_Y))
    print("recall_score:")
    print(metrics.recall_score(test_Y, pred_Y))
    print("precision_score:")
    print(metrics.precision_score(test_Y, pred_Y))
    print("confusion_matrix:")
    print(metrics.confusion_matrix(test_Y, pred_Y))
    print("AUC:")
    print(metrics.roc_auc_score(test_Y, ranking))

    f_pos, t_pos, _thresh = metrics.roc_curve(test_Y, ranking)
    auc_area = metrics.auc(f_pos, t_pos)

    plt.plot(f_pos, t_pos, color='darkorange', lw=2,
             label='AUC = %.2f' % auc_area)
    plt.legend(loc='lower right')
    # Diagonal reference line from (0, 0) to (1, 1).
    plt.plot([0, 1], [0, 1], color='navy', linestyle='--')
    plt.title('ROC')
    plt.ylabel('True Pos Rate')
    plt.xlabel('False Pos Rate')
    plt.show()


def main():
    """Run the grid search and print the best parameters and score."""
    x, y = datasets.make_classification(
        n_samples=1000, n_features=100, n_redundant=0, random_state=1)

    # Optional baseline (also disabled in the original script): train one
    # untuned model on a hold-out split and print its metrics with report():
    #     train_X, test_X, train_Y, test_Y = train_test_split(
    #         x, y, test_size=0.2, random_state=66)
    #     clf = GradientBoostingClassifier(n_estimators=100)
    #     clf.fit(train_X, train_Y)
    #     report(test_Y, clf.predict(test_X))

    parameters = {
        'n_estimators': range(50, 200, 25),
        'max_depth': range(2, 10, 2),
    }
    # The former ``iid=False`` argument is intentionally not passed: it was
    # removed from GridSearchCV in scikit-learn 0.24 and raises TypeError
    # there. Its behaviour (best_score_ is the plain mean over the folds) is
    # what current versions always do, so the reported score is unchanged.
    gsearch = GridSearchCV(
        estimator=GradientBoostingClassifier(),
        param_grid=parameters,
        scoring='accuracy',
        cv=5,
    )
    # The search cross-validates on the full data set, as in the original.
    gsearch.fit(x, y)

    print("gsearch.best_params_")
    print(gsearch.best_params_)
    print("gsearch.best_score_")
    print(gsearch.best_score_)


if __name__ == '__main__':
    main()
效果（scoring='accuracy'，5 折交叉验证；GradientBoostingClassifier 未固定 random_state，每次运行的结果可能略有不同）:
gsearch.best_params_
{'max_depth': 4, 'n_estimators': 100}
gsearch.best_score_
0.868142228555714