机器学习之调参

导入数据:

from sklearn.neighbors import KNeighborsClassifier
from sklearn.model_selection import train_test_split
from sklearn.model_selection import KFold 
from sklearn.datasets import load_wine

wine = load_wine()
X = wine.data
y = wine.target

#splitting the data into train and test set
X_train,X_test,y_train,y_test = train_test_split(X,y,test_size = 0.3,random_state = 14)

调参方法:

1、网格搜索:

from sklearn.model_selection import GridSearchCV

knn = KNeighborsClassifier()
grid_param = { 'n_neighbors' : list(range(2,11)) , 
              'algorithm' : ['auto','ball_tree','kd_tree','brute'] }
              
grid = GridSearchCV(knn,grid_param,cv = 5)
grid.fit(X_train,y_train)

#best parameter combination
grid.best_params_  #{'algorithm': 'auto', 'n_neighbors': 5}

#Score achieved with best parameter combination
grid.best_score_  #0.774

#all combinations of hyperparameters
grid.cv_results_['params']

#average scores of cross-validation
grid.cv_results_['mean_test_score']

 

2、贝叶斯搜索:

from skopt import BayesSearchCV

import warnings
warnings.filterwarnings("ignore")

# parameter ranges are specified by one of below
from skopt.space import Real, Categorical, Integer

knn = KNeighborsClassifier()
#defining hyper-parameter grid
grid_param = { 'n_neighbors' : list(range(2,11)) , 
              'algorithm' : ['auto','ball_tree','kd_tree','brute'] }

#initializing Bayesian Search
Bayes = BayesSearchCV(knn , grid_param , n_iter=30 , random_state=14)
Bayes.fit(X_train,y_train)

#best parameter combination
Bayes.best_params_  #OrderedDict([('algorithm', 'ball_tree'), ('n_neighbors', 5)])

#score achieved with best parameter combination
Bayes.best_score_  #0.7741935483870968

#all combinations of hyperparameters
Bayes.cv_results_['params']

#average scores of cross-validation
Bayes.cv_results_['mean_test_score']

 

网格搜索缺点:由于它尝试了超参数的每一个组合,并根据K折交叉验证得分选择了最佳组合,这使得GridsearchCV非常慢。

贝叶斯搜索缺点:要在2维或3维的搜索空间中得到一个好的代理曲面需要十几个样本,增加搜索空间的维数需要更多的样本。

 

除此之外还有传统手工搜索及随机搜索,未使用过,不推荐。

posted @ 2020-11-30 21:00  17孤独的牧羊人  阅读(115)  评论(0编辑  收藏  举报