随笔 - 168, 文章 - 0, 评论 - 10, 阅读 - 35万

导航

< 2025年1月 >
29 30 31 1 2 3 4
5 6 7 8 9 10 11
12 13 14 15 16 17 18
19 20 21 22 23 24 25
26 27 28 29 30 31 1
2 3 4 5 6 7 8

sklearn的GridSearchCV例子

Posted on   wzd321  阅读(685)  评论(0编辑  收藏  举报
class sklearn.model_selection.GridSearchCV(estimator, param_grid, scoring=None, fit_params=None, n_jobs=1, iid=True, refit=True, cv=None, verbose=0, pre_dispatch='2*n_jobs', error_score='raise', return_train_score=True)

1.estimator:

  传入估计器与不需要调参的参数,每一个估计器都需要一个scoring参数。

2.param_grid:

  需要最优化的参数的取值,值为字典或者列表。

3.scoring:

  模型评价标准,默认None,这时需要使用score函数,根据所选模型不同,评价准则不同。字符串或者自定义形如:scorer(estimator, X, y);如果是None,则使用estimator的误差估计函数。

4.n_jobs

  n_jobs: 并行数,int:个数,-1:跟CPU核数一致。

5.refit=True

  默认为True,程序将会以交叉验证训练集得到的最佳参数,重新对所有可用的训练集与开发集进行,作为最终用于性能评估的最佳模型参数。即在搜索参数结束后,用最佳参数结果再次fit一遍全部数据集。

6.pre_dispatch=‘2*n_jobs’

  指定总共分发的并行任务数。当n_jobs大于1时,数据将在每个运行点进行复制,这可能导致OOM,而设置pre_dispatch参数,则可以预先划分总共的job数量,使数据最多被复制pre_dispatch次。

复制代码
from sklearn.datasets import load_iris
import pandas as pd
from sklearn.ensemble import RandomForestClassifier
from sklearn.model_selection import GridSearchCV
from sklearn.metrics import classification_report

X,y = load_iris(return_X_y=True)  
df_X = pd.DataFrame(X,columns=list("ABCD"))


#gridSearchCV
parameters = [{'n_estimators':[10,100,1000],
               'criterion':['entropy','gini'],
               'max_depth':[10,50,100,200],
               'min_samples_split':[2,5,10],
               'min_weight_fraction_leaf':[0.0,0.1,0.2,0.3,0.4,0.5]}]

parameters = [{'n_estimators':[10,20]}]

#scoring="precision"或者"recall"或者"roc_auc","accuracy"或者None

clf = GridSearchCV(RandomForestClassifier(), parameters,cv=2,scoring="accuracy")
clf.fit(df_X,y)

clf.cv_results_
# =============================================================================
# {'mean_fit_time': array([0.0089916 , 0.01695275]),
#  'mean_score_time': array([0.00099409, 0.00148273]),
#  'mean_test_score': array([0.94666667, 0.96      ]),
#  'mean_train_score': array([0.98666667, 1.        ]),
#  'param_n_estimators': masked_array(data=[10, 20],
#               mask=[False, False],
#         fill_value='?',
#              dtype=object),
#  'params': [{'n_estimators': 10}, {'n_estimators': 20}],
#  'rank_test_score': array([2, 1]),
#  'split0_test_score': array([0.96, 0.96]),
#  'split0_train_score': array([1., 1.]),
#  'split1_test_score': array([0.93333333, 0.96      ]),
#  'split1_train_score': array([0.97333333, 1.        ]),
#  'std_fit_time': array([1.01363659e-03, 9.53674316e-07]),
#  'std_score_time': array([4.17232513e-06, 5.05685806e-04]),
#  'std_test_score': array([0.01333333, 0.        ]),
#  'std_train_score': array([0.01333333, 0.        ])}
# =============================================================================
clf.best_estimator_
# =============================================================================
# RandomForestClassifier(bootstrap=True, class_weight=None, criterion='gini',
#             max_depth=None, max_features='auto', max_leaf_nodes=None,
#             min_impurity_decrease=0.0, min_impurity_split=None,
#             min_samples_leaf=1, min_samples_split=2,
#             min_weight_fraction_leaf=0.0, n_estimators=20, n_jobs=1,
#             oob_score=False, random_state=None, verbose=0,
#             warm_start=False)
# =============================================================================

clf.best_score_
# =============================================================================
# Out[42]: 0.96
# 
# =============================================================================

clf.best_params_

# =============================================================================
# Out[43]: {'n_estimators': 20}
# 
# =============================================================================
clf.grid_scores_

# =============================================================================
# [mean: 0.94667, std: 0.01333, params: {'n_estimators': 10},
#  mean: 0.96000, std: 0.00000, params: {'n_estimators': 20}]
# =============================================================================
复制代码

 

 

 

 

 

 

 

 

 

 

 

 

 

参考:http://blog.51cto.com/emily18/2088128

 

编辑推荐:
· .NET Core GC压缩(compact_phase)底层原理浅谈
· 现代计算机视觉入门之:什么是图片特征编码
· .NET 9 new features-C#13新的锁类型和语义
· Linux系统下SQL Server数据库镜像配置全流程详解
· 现代计算机视觉入门之:什么是视频
阅读排行:
· 【译】我们最喜欢的2024年的 Visual Studio 新功能
· 个人数据保全计划:从印象笔记迁移到joplin
· Vue3.5常用特性整理
· 重拾 SSH:从基础到安全加固
· 为什么UNIX使用init进程启动其他进程?
点击右上角即可分享
微信分享提示