模型调参---GridSearchCV
undefined
一. GridSearchCV参数介绍
导入模块:
1 | from sklearn.model_selection import GridSearchCV |
GridSearchCV 称为网格搜索交叉验证调参,它通过遍历传入的参数的所有排列组合,通过交叉验证的方式,返回所有参数组合下的评价指标得分,GridSearchCV 函数的参数详细解释如下:
1 | class sklearn.model_selection.GridSearchCV(estimator, param_grid, scoring = None , fit_params = None , n_jobs = 1 , iid = True , refit = True , cv = None , verbose = 0 , pre_dispatch = '2*n_jobs' , error_score = 'raise' , return_train_score = True ) |
参数:
- estimator:scikit-learn 库里的算法模型;
- param_grid:需要搜索调参的参数字典;
- scoring:评价指标,可以是 auc, rmse,logloss等;
- n_jobs:并行计算线程个数,可以设置为 -1,这样可以充分使用机器的所有处理器,并行数量越多,有利于缩短调参时间;
- iid:如果设置为True,则默认假设数据在每折中具有相同地分布,并且最小化的损失是每个样本的总损失,而不是每折的平均损失。简单点说,就是如果你可以确定 cv 中每折数据分布一致就设置为 True,否则设置为 False;
- cv:交叉验证的折数,默认为3折;
常用属性:
- cv_results_:用来输出cv结果的,可以是字典形式也可以是numpy形式,还可以转换成DataFrame格式
- best_estimator_:通过搜索参数得到的最好的估计器,当参数refit=False时该对象不可用
- best_score_:float类型,输出最好的成绩
- best_params_:通过网格搜索得到的score最好对应的参数
- best_index_:对应于最佳候选参数设置的索引(cv_results_数组)。cv_results _ [‘params’] [search.best_index_]中的dict给出了最佳模型的参数设置,给出了最高的平均分数(search.best_score_)。
- scorer_:评分函数
- n_splits_:交叉验证的数量
- refit_time_:refit所用的时间,当参数refit=False时该对象不可用
常用函数:
- decision_function(X):返回决策函数值(比如svm中的决策距离)
- fit(X,y=None,groups=None,fit_params):在数据集上运行所有的参数组合
- get_params(deep=True):返回估计器的参数
- inverse_transform(Xt):Call inverse_transform on the estimator with the best found params.
- predict(X):返回预测结果值(0/1)
- predict_log_proba(X): Call predict_log_proba on the estimator with the best found parameters.
- predict_proba(X):返回每个类别的概率值(有几类就返回几列值)
- score(X, y=None):返回函数
- set_params(**params):Set the parameters of this estimator.
- transform(X):在X上使用训练好的参数
属性grid_scores_已经被删除,改用:
means = grid_search.cv_results_[ 'mean_test_score' ] params = grid_search.cv_results_[ 'params' ] |
举例:
使用多评价指标,必须设置refit参数,可以显示多指标的结果,但是最后显示最佳的参数时候必须指定一个指标,详解:解决方法
param_test2 = { 'max_depth' :[ 3 , 4 , 5 , 6 ], 'min_child_weight' :[ 0.5 , 1 , 1.5 ]} scorers = { 'precision_score' : make_scorer(precision_score), 'recall_score' : make_scorer(recall_score), 'accuracy_score' : make_scorer(accuracy_score) } gsearch2 = GridSearchCV(estimator = XGBClassifier( learning_rate = 0.1 , n_estimators = 270 , max_depth = 4 ,min_child_weight = 1 , gamma = 0 , subsample = 0.8 ,\ colsample_bytree = 0.8 , objective = 'binary:logistic' , nthread = 4 ,\ scale_pos_weight = 1 , seed = 27 ),\ param_grid = param_test1,scoring = scorers,refit = 'precision_score' ,n_jobs = 4 ,iid = False , cv = 5 ) gsearch2.fit(x_train_resampled,y_train_resampled) |
查看最佳结果:
>>>gsearch2.best_params_,gsearch2.best_score_,gsearch2.cv_results_[ 'mean_test_precision_score' ],gsearch2.cv_results_[ 'params' ] ({ 'max_depth' : 9 , 'min_child_weight' : 1 }, 0.8278796760710192 , array([ 0.79985227 , 0.80330522 , 0.80645782 , 0.8223829 , 0.81170396 , 0.80891565 , 0.82691152 , 0.82032078 , 0.82220572 , 0.82787968 , 0.82439509 , 0.81863326 ]), [{ 'max_depth' : 3 , 'min_child_weight' : 1 }, { 'max_depth' : 3 , 'min_child_weight' : 3 }, { 'max_depth' : 3 , 'min_child_weight' : 5 }, { 'max_depth' : 5 , 'min_child_weight' : 1 }, { 'max_depth' : 5 , 'min_child_weight' : 3 }, { 'max_depth' : 5 , 'min_child_weight' : 5 }, { 'max_depth' : 7 , 'min_child_weight' : 1 }, { 'max_depth' : 7 , 'min_child_weight' : 3 }, { 'max_depth' : 7 , 'min_child_weight' : 5 }, { 'max_depth' : 9 , 'min_child_weight' : 1 }, { 'max_depth' : 9 , 'min_child_weight' : 3 }, { 'max_depth' : 9 , 'min_child_weight' : 5 }]) |
查看交叉验证的中间结果:
1 | pd.DataFrame(gsearch2.cv_results_) |
画图显示最佳参数:
grid_visualization = [] for grid_pair in gsearch2.cv_results_[ 'mean_test_precision_score' ]: grid_visualization.append(grid_pair) grid_visualization = np.array(grid_visualization) grid_visualization.shape = ( 4 , 3 ) sns.heatmap(grid_visualization,annot = True ,cmap = 'Blues' ,fmt = '.3f' ) plt.xticks(np.arange( 3 ) + 0.5 ,gsearch2.param_grid[ 'min_child_weight' ]) plt.yticks(np.arange( 4 ) + 0.5 ,gsearch2.param_grid[ 'max_depth' ]) plt.xlabel( 'min_child_weight' ) plt.ylabel( 'max_depth' ) |
参考文献:
【2】python机器学习库sklearn——参数优化(网格搜索GridSearchCV、随机搜索RandomizedSearchCV、hyperopt)
【4】使用GridSearchCV进行网格搜索(比较全)
【推荐】国内首个AI IDE,深度理解中文开发场景,立即下载体验Trae
【推荐】编程新体验,更懂你的AI,立即体验豆包MarsCode编程助手
【推荐】抖音旗下AI助手豆包,你的智能百科全书,全免费不限次数
【推荐】轻量又高性能的 SSH 工具 IShell:AI 加持,快人一步
· go语言实现终端里的倒计时
· 如何编写易于单元测试的代码
· 10年+ .NET Coder 心语,封装的思维:从隐藏、稳定开始理解其本质意义
· .NET Core 中如何实现缓存的预热?
· 从 HTTP 原因短语缺失研究 HTTP/2 和 HTTP/3 的设计差异
· 分享一个免费、快速、无限量使用的满血 DeepSeek R1 模型,支持深度思考和联网搜索!
· 使用C#创建一个MCP客户端
· ollama系列1:轻松3步本地部署deepseek,普通电脑可用
· 基于 Docker 搭建 FRP 内网穿透开源项目(很简单哒)
· 按钮权限的设计及实现