GridSearchCV网格搜索得到最佳超参数, 在K近邻算法中的应用

  最近在学习机器学习中的K近邻算法, KNeighborsClassifier 看似简单实则里面有很多的参数配置, 这些参数直接影响到预测的准确率. 很自然的问题就是如何找到最优参数配置? 这就需要用到GridSearchCV 网格搜索模型. 

  在没有学习到GridSearchCV 网格搜索模型之前, 寻找最优参数配置是通过人为改变参数, 来观察预测结果准确率的. 具体步骤如下:

  1. 修改参数配置
  2. fit 训练集
  3. 预测测试集
  4. 预测结果与真实结果对比
  5. 重复上述步骤

  GridSearchCV 网格搜索模型寻找最优参数的步骤如下:

  1. 将各种参数配置封装为列表
  2. 实例化分类器
  3. 使用GridSearchCV 为分类器和参数建模
  4. 实例化模型, 并用新的模型对象fit训练集
  5. 得到最好的参数配置
  6. 用最优参数去预测数据

  于是我的疑问就来了, GridSearchCV 并没有去预测测试集,进而得到预测结果,并在与真实结果的对比中找到最优的参数配置, 没有这个步骤,它是怎么得到最优参数的? 搜索了很多,终于在这个网页中得到了想要的信息: python – GridSearchCV是否执行交叉验证? http://www.cocoachina.com/articles/67515 

  简单说就是我们把训练集传递给GridSearchCV, 它会进一步将训练集分为训练集和测试集, 然后通过不断调整超参数, 进行交叉验证, 最后获得最优参数. 

  GridSearchCV会主动将数据分为训练集和测试集,这就是原因所在了.

  代码实现:

 1 from sklearn import datasets
 2 from sklearn.model_selection import train_test_split
 3 from sklearn.neighbors import KNeighborsClassifier
 4 from sklearn.metrics import accuracy_score
 5 from sklearn.model_selection import GridSearchCV
 6 
 7 
 8 # 1/获取数据
 9 digits = datasets.load_digits()
10 X = digits.data
11 y = digits.target
12 
13 # 2/分割数据,得到训练集和测试集
14 X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=1)
15 
16 
17 # 3/超参数配置
18 param_grid = [
19     {
20         "weights":["uniform"],
21         "n_neighbors":[i for i in range(1,11)]
22     },
23     {
24         "weights":["distance"],
25         "n_neighbors":[i for i in range(1,11)],
26         "p":[i for i in range(1,6)]
27     }
28 ]
29 
30 
31 # 4/为分类器和超参数搭建模型
32 knn_clf = KNeighborsClassifier()
33 grid_search = GridSearchCV(knn_clf, param_grid, n_jobs=-1, verbose=2)
34 
35 # 5/实例化模型(多种参数配置的分类器)fit训练集,
36 # 本质上是将训练集进一步分为训练集和测试集,得到最好的参数配置
37 # 因为要不断尝试各种参数交叉验证,所以非常耗时
38 grid_search.fit(X_train, y_train)
39 
40 # 6/
41 # 最终拿到最佳参数配置分类器 best_estimator_
42 knn_clf = grid_search.best_estimator_
43 
44 # 7/使用最佳分类器对测试集预测
45 y_predict = knn_clf.predict(X_test)
46 
47 # 8/得到准确率
48 accuracy_score(y_test, y_predict))

 

posted @ 2020-07-29 05:44  止一  阅读(910)  评论(0编辑  收藏  举报