使用GridSearchCV寻找最佳参数组合——机器学习工具箱代码

 

 

# -*- coding: utf-8 -*-
import numpy as np
from sklearn.feature_extraction import FeatureHasher
from sklearn import datasets
from sklearn.ensemble import GradientBoostingClassifier
from sklearn.neighbors import KNeighborsClassifier
import xgboost as xgb
from sklearn.model_selection import GridSearchCV
from sklearn.model_selection import train_test_split
from sklearn import metrics
from matplotlib import pyplot as plt
from sklearn.ensemble import GradientBoostingClassifier
from sklearn.model_selection import GridSearchCV

def report(test_Y, pred_Y, pred_score=None):
    """Print binary-classification metrics and plot the ROC curve.

    Args:
        test_Y: Ground-truth labels.
        pred_Y: Predicted hard labels, same length as ``test_Y``.
        pred_score: Optional continuous scores for the positive class,
            e.g. ``clf.predict_proba(X)[:, 1]`` or ``clf.decision_function(X)``.
            When given, AUC and the ROC curve are computed from these scores.
            When omitted, they fall back to the hard labels in ``pred_Y``,
            which yields a degenerate ROC curve with a single operating point.

    Note:
        f1/recall/precision use scikit-learn's default binary averaging, so
        this assumes a binary problem whose positive class is labelled 1.
    """
    # ROC/AUC need a ranking of samples; hard 0/1 labels provide only one
    # threshold, so prefer real scores when the caller supplies them.
    roc_input = pred_Y if pred_score is None else pred_score

    print("accuracy_score:")
    print(metrics.accuracy_score(test_Y, pred_Y))
    print("f1_score:")
    print(metrics.f1_score(test_Y, pred_Y))
    print("recall_score:")
    print(metrics.recall_score(test_Y, pred_Y))
    print("precision_score:")
    print(metrics.precision_score(test_Y, pred_Y))
    print("confusion_matrix:")
    print(metrics.confusion_matrix(test_Y, pred_Y))

    # Compute the AUC once and reuse it for both the printout and the legend.
    auc_area = metrics.roc_auc_score(test_Y, roc_input)
    print("AUC:")
    print(auc_area)

    f_pos, t_pos, _ = metrics.roc_curve(test_Y, roc_input)
    # Start a fresh figure so repeated calls do not draw over each other.
    plt.figure()
    plt.plot(f_pos, t_pos, 'darkorange', lw=2, label='AUC = %.2f' % auc_area)
    plt.legend(loc='lower right')
    # Diagonal reference line: the ROC curve of a random classifier.
    plt.plot([0, 1], [0, 1], color='navy', linestyle='--')
    plt.title('ROC')
    plt.ylabel('True Pos Rate')
    plt.xlabel('False Pos Rate')
    plt.show()



if __name__ == '__main__':
    # Synthetic binary classification problem: 1000 samples, 100 features.
    x, y = datasets.make_classification(n_samples=1000, n_features=100,
                                        n_redundant=0, random_state=1)
    # Hold out 20% so the tuned model is evaluated on samples the grid
    # search never saw.
    train_X, test_X, train_Y, test_Y = train_test_split(x,
                                                        y,
                                                        test_size=0.2,
                                                        random_state=66)

    # Metric used to rank parameter combinations during cross-validation
    # (the original run recorded below used accuracy).
    scoring = "accuracy"
    parameters = {'n_estimators': range(50, 200, 25),  # 50, 75, ..., 175
                  'max_depth': range(2, 10, 2)}        # 2, 4, 6, 8
    # `iid` was deprecated in scikit-learn 0.22 and removed in 0.24 (passing
    # it raises TypeError), so it is not passed. random_state makes the
    # boosting runs reproducible.
    gsearch = GridSearchCV(estimator=GradientBoostingClassifier(random_state=1),
                           param_grid=parameters,
                           scoring=scoring,
                           cv=5)
    # Tune on the training split only; fitting on the full (x, y) would leak
    # the held-out samples into model selection.
    gsearch.fit(train_X, train_Y)
    print("gsearch.best_params_")
    print(gsearch.best_params_)
    print("gsearch.best_score_")
    print(gsearch.best_score_)
    # With the default refit=True, the best parameters are refit on all of
    # train_X, and score() evaluates that model on the held-out split.
    print("test %s:" % scoring)
    print(gsearch.score(test_X, test_Y))

效果（某一次运行的输出；具体数值会随 scikit-learn 版本和随机种子而变化）:

gsearch.best_params_
{'max_depth': 4, 'n_estimators': 100}
gsearch.best_score_
0.868142228555714

posted @ 2018-06-08 10:09  bonelee  阅读(649)  评论(0)  编辑  收藏  举报