使用GridSearchCV寻找最佳参数组合——机器学习工具箱代码
# -*- coding: utf-8 -*-
"""Search for the best GradientBoostingClassifier parameters with GridSearchCV.

A synthetic binary classification problem is generated, a small grid over
``n_estimators`` and ``max_depth`` is evaluated with 5-fold cross-validation,
and the best parameter combination and its mean CV score are printed.
"""
import numpy as np
from sklearn.feature_extraction import FeatureHasher
from sklearn import datasets
from sklearn.ensemble import GradientBoostingClassifier
from sklearn.neighbors import KNeighborsClassifier
import xgboost as xgb
from sklearn.model_selection import GridSearchCV
from sklearn.model_selection import train_test_split
from sklearn import metrics
from matplotlib import pyplot as plt


def report(test_Y, pred_Y):
    """Print classification metrics and plot the ROC curve.

    Args:
        test_Y: Ground-truth labels.
        pred_Y: Predicted labels or scores for the same samples. They are
            passed both to the label-based metrics (accuracy, F1, recall,
            precision, confusion matrix) and, as the score argument, to
            ``roc_auc_score`` / ``roc_curve``.

    The function returns nothing; it prints each metric and then opens a
    matplotlib window with the ROC curve (``plt.show()``).
    """
    print("accuracy_score:")
    print(metrics.accuracy_score(test_Y, pred_Y))
    print("f1_score:")
    print(metrics.f1_score(test_Y, pred_Y))
    print("recall_score:")
    print(metrics.recall_score(test_Y, pred_Y))
    print("precision_score:")
    print(metrics.precision_score(test_Y, pred_Y))
    print("confusion_matrix:")
    print(metrics.confusion_matrix(test_Y, pred_Y))
    print("AUC:")
    print(metrics.roc_auc_score(test_Y, pred_Y))

    f_pos, t_pos, thresh = metrics.roc_curve(test_Y, pred_Y)
    auc_area = metrics.auc(f_pos, t_pos)
    plt.plot(f_pos, t_pos, 'darkorange', lw=2, label='AUC = %.2f' % auc_area)
    plt.legend(loc='lower right')
    # Diagonal reference line: the ROC of a random classifier.
    plt.plot([0, 1], [0, 1], color='navy', linestyle='--')
    plt.title('ROC')
    plt.ylabel('True Pos Rate')
    plt.xlabel('False Pos Rate')
    plt.show()


if __name__ == '__main__':
    x, y = datasets.make_classification(
        n_samples=1000, n_features=100, n_redundant=0, random_state=1)
    train_X, test_X, train_Y, test_Y = train_test_split(
        x, y, test_size=0.2, random_state=66)

    # Untuned baseline, left disabled:
    # clf = GradientBoostingClassifier(n_estimators=100)
    # clf.fit(train_X, train_Y)
    # pred_Y = clf.predict(test_X)
    # report(test_Y, pred_Y)

    parameters = {
        'n_estimators': range(50, 200, 25),
        'max_depth': range(2, 10, 2),
    }
    # The `iid` argument was removed from GridSearchCV in scikit-learn 0.24;
    # passing it raises TypeError, so it is no longer supplied.
    gsearch = GridSearchCV(
        estimator=GradientBoostingClassifier(),
        param_grid=parameters,
        scoring='accuracy',
        cv=5,
    )
    # NOTE: the search is cross-validated on the full data set (x, y), not
    # on the train split above.
    gsearch.fit(x, y)

    print("gsearch.best_params_")
    print(gsearch.best_params_)
    print("gsearch.best_score_")
    print(gsearch.best_score_)
效果:
gsearch.best_params_
{'max_depth': 4, 'n_estimators': 100}
gsearch.best_score_
0.868142228555714
标签:
机器学习
【推荐】国内首个AI IDE,深度理解中文开发场景,立即下载体验Trae
【推荐】编程新体验,更懂你的AI,立即体验豆包MarsCode编程助手
【推荐】抖音旗下AI助手豆包,你的智能百科全书,全免费不限次数
【推荐】轻量又高性能的 SSH 工具 IShell:AI 加持,快人一步
· 记一次.NET内存居高不下排查解决与启示
· 探究高空视频全景AR技术的实现原理
· 理解Rust引用及其生命周期标识(上)
· 浏览器原生「磁吸」效果!Anchor Positioning 锚点定位神器解析
· 没有源码,如何修改代码逻辑?
· 全程不用写代码,我用AI程序员写了一个飞机大战
· MongoDB 8.0这个新功能碉堡了,比商业数据库还牛
· 记一次.NET内存居高不下排查解决与启示
· 白话解读 Dapr 1.15:你的「微服务管家」又秀新绝活了
· DeepSeek 开源周回顾「GitHub 热点速览」