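Using StratifiedKFold and GridSearchCV before and after the scikit-learn version change
The scikit-learn documentation explains that, starting with version 0.18, GridSearchCV, StratifiedKFold and the other cross-validation utilities live in the sklearn.model_selection module. The old sklearn.cross_validation and sklearn.grid_search modules were deprecated in 0.18 and removed in 0.20.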
So, with the old version, the import was:
from sklearn.grid_search import GridSearchCV
With the new version, it is:
from sklearn.model_selection import GridSearchCV
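StratifiedKFold has the same issue: the old version imported it from sklearn.cross_validation, the new one from sklearn.model_selection. I use PyCharm, and the IDE flags this automatically.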
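If the same script has to run on both old and new scikit-learn installs, one option is a guarded import, sketched below. This only fixes the import path; the StratifiedKFold constructor still differs between versions (labels plus n_folds in the old API, n_splits only in the new one), so the calling code must match whichever branch succeeded.

```python
# Prefer the new module (scikit-learn >= 0.18); fall back to the
# pre-0.18 locations only if sklearn.model_selection does not exist.
try:
    from sklearn.model_selection import GridSearchCV, StratifiedKFold
    NEW_API = True
except ImportError:
    # Old layout: GridSearchCV in sklearn.grid_search,
    # StratifiedKFold in sklearn.cross_validation
    from sklearn.grid_search import GridSearchCV
    from sklearn.cross_validation import StratifiedKFold
    NEW_API = False

print("model_selection API available:", NEW_API)
```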
<----------------------------------分割线------------------------------------------->
In earlier versions, StratifiedKFold and GridSearchCV were combined like this.
As an example, I use a decision tree:
# Old API (scikit-learn < 0.18; these modules were removed in 0.20)
from sklearn.grid_search import GridSearchCV
from sklearn.cross_validation import StratifiedKFold
from sklearn.tree import DecisionTreeClassifier

decision_tree_classifier = DecisionTreeClassifier()
parameter_grid = {'max_depth': [1, 2, 3, 4, 5],
                  'max_features': [1, 2, 3, 4]}

# The old StratifiedKFold takes the class labels in its constructor
cross_validation = StratifiedKFold(all_classes, n_folds=10)

grid_search = GridSearchCV(decision_tree_classifier,
                           param_grid=parameter_grid,
                           cv=cross_validation)
grid_search.fit(all_inputs, all_classes)

print('Best score: {}'.format(grid_search.best_score_))
print('Best parameters: {}'.format(grid_search.best_params_))
After the upgrade, StratifiedKFold and GridSearchCV are combined like this:
from sklearn.model_selection import GridSearchCV
from sklearn.model_selection import StratifiedKFold
from sklearn.tree import DecisionTreeClassifier

decision_tree_classifier = DecisionTreeClassifier()
parameter_grid = {'max_depth': [1, 2, 3, 4, 5],
                  'max_features': [1, 2, 3, 4]}

# New API: no labels in the constructor, and n_folds is now n_splits
skf = StratifiedKFold(n_splits=10)

# Pass the splitter itself; GridSearchCV calls skf.split(X, y) during fit.
# skf.get_n_splits(...) would only return the integer 10.
grid_search = GridSearchCV(decision_tree_classifier,
                           param_grid=parameter_grid,
                           cv=skf)
grid_search.fit(all_inputs, all_classes)

print("Best score:", grid_search.best_score_)
print("Best param:", grid_search.best_params_)
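The post never shows where all_inputs and all_classes come from. Below is a self-contained sketch that assumes they are the iris features and labels (iris has 4 features, which fits max_features values up to 4). It also sets shuffle=True with a fixed random_state, an illustrative choice of mine rather than part of the original, to show why the StratifiedKFold object should be passed as cv: get_n_splits() returns a plain integer, and an integer carries none of those settings.

```python
from sklearn.datasets import load_iris
from sklearn.model_selection import GridSearchCV, StratifiedKFold
from sklearn.tree import DecisionTreeClassifier

# Assumption: the post's all_inputs / all_classes are the iris data
all_inputs, all_classes = load_iris(return_X_y=True)

decision_tree_classifier = DecisionTreeClassifier(random_state=0)
parameter_grid = {'max_depth': [1, 2, 3, 4, 5],
                  'max_features': [1, 2, 3, 4]}

# shuffle / random_state are illustrative, not from the original post
skf = StratifiedKFold(n_splits=10, shuffle=True, random_state=0)

# get_n_splits() only reports the fold count; using it as cv would
# drop the shuffle settings above, so the skf object is passed instead
n_folds = skf.get_n_splits(all_inputs, all_classes)
print("Folds:", n_folds)

grid_search = GridSearchCV(decision_tree_classifier,
                           param_grid=parameter_grid,
                           cv=skf)
grid_search.fit(all_inputs, all_classes)

print("Best score:", grid_search.best_score_)
print("Best param:", grid_search.best_params_)
```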
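Comparing the two versions, the main change is in the StratifiedKFold() arguments: the class labels are no longer passed to the constructor, n_folds has been renamed n_splits, and the labels are instead given to split(X, y), which GridSearchCV calls for you during fit. Also note that skf.get_n_splits() only returns the number of folds as an integer, so pass the skf object itself as cv. For more details, see the StratifiedKFold and GridSearchCV pages in the scikit-learn documentation.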
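To see what the new interface does under the hood, here is a small sketch (again assuming the iris data, which has 150 samples, 50 per class) that calls split(X, y) directly. Each iteration yields a pair of index arrays, and np.bincount counts how many samples of each class land in the test fold.

```python
import numpy as np
from sklearn.datasets import load_iris
from sklearn.model_selection import StratifiedKFold

X, y = load_iris(return_X_y=True)  # 150 samples, 50 per class

# New API: no labels in the constructor, and n_folds is now n_splits
skf = StratifiedKFold(n_splits=10)

fold_class_counts = []
# The labels go to split(), which yields (train_idx, test_idx) pairs
for fold, (train_idx, test_idx) in enumerate(skf.split(X, y)):
    # Number of test samples from each of the 3 classes in this fold
    counts = np.bincount(y[test_idx], minlength=3)
    fold_class_counts.append(counts.tolist())
    print(fold, len(train_idx), len(test_idx), counts)
```

With 50 samples per class and 10 folds, every test fold holds 15 samples, 5 from each class, so the 1:1:1 class ratio is preserved. GridSearchCV makes this same split(X, y) call internally when it is given cv=skf.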