Using GridSearchCV to find the best parameter combination: machine learning toolbox code

 

 

# -*- coding: utf-8 -*-
from sklearn import datasets
from sklearn.ensemble import GradientBoostingClassifier
from sklearn.model_selection import GridSearchCV
from sklearn.model_selection import train_test_split
from sklearn import metrics
from matplotlib import pyplot as plt
 
def report(test_Y, pred_Y):
    # Print the standard binary-classification metrics for a set of predictions.
    print("accuracy_score:")
    print(metrics.accuracy_score(test_Y, pred_Y))
    print("f1_score:")
    print(metrics.f1_score(test_Y, pred_Y))
    print("recall_score:")
    print(metrics.recall_score(test_Y, pred_Y))
    print("precision_score:")
    print(metrics.precision_score(test_Y, pred_Y))
    print("confusion_matrix:")
    print(metrics.confusion_matrix(test_Y, pred_Y))
    print("AUC:")
    print(metrics.roc_auc_score(test_Y, pred_Y))

    # Plot the ROC curve. Note: pred_Y holds hard 0/1 labels here, so the curve has
    # only a few points; passing predict_proba scores would give a smoother curve.
    f_pos, t_pos, thresh = metrics.roc_curve(test_Y, pred_Y)
    auc_area = metrics.auc(f_pos, t_pos)
    plt.plot(f_pos, t_pos, color='darkorange', lw=2, label='AUC = %.2f' % auc_area)
    plt.legend(loc='lower right')
    plt.plot([0, 1], [0, 1], color='navy', linestyle='--')
    plt.title('ROC')
    plt.ylabel('True Pos Rate')
    plt.xlabel('False Pos Rate')
    plt.show()
 
 
 
if __name__ == '__main__':
    # Build a synthetic binary classification problem: 1000 samples, 100 features.
    x, y = datasets.make_classification(n_samples=1000, n_features=100,
                                        n_redundant=0, random_state=1)
    train_X, test_X, train_Y, test_Y = train_test_split(x,
                                                        y,
                                                        test_size=0.2,
                                                        random_state=66)
    # Baseline without tuning:
    #clf = GradientBoostingClassifier(n_estimators=100)
    #clf.fit(train_X, train_Y)
    #pred_Y = clf.predict(test_X)
    #report(test_Y, pred_Y)

    # Grid of candidate hyperparameters for GradientBoostingClassifier.
    parameters = {'n_estimators': range(50, 200, 25), 'max_depth': range(2, 10, 2)}
    # 5-fold cross-validated grid search, scored by accuracy
    # (the old iid= argument was removed in scikit-learn 0.24, so it is not passed).
    gsearch = GridSearchCV(estimator=GradientBoostingClassifier(),
                           param_grid=parameters,
                           scoring='accuracy',
                           cv=5)
    gsearch.fit(x, y)
    print("gsearch.best_params_")
    print(gsearch.best_params_)
    print("gsearch.best_score_")
    print(gsearch.best_score_)

Result:

gsearch.best_params_
{'max_depth': 4, 'n_estimators': 100}
gsearch.best_score_
0.868142228555714
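
With best_params_ in hand, the natural next step is to score the tuned model on the held-out test split. Below is a minimal sketch (not part of the original run, and not the original fit on the full x/y) that fits the same grid search on train_X/train_Y only, so the test set stays unseen, and then reuses the report() helper defined above:

# Sketch: search on the training split only, then evaluate the refitted winner.
gsearch = GridSearchCV(estimator=GradientBoostingClassifier(),
                       param_grid=parameters,
                       scoring='accuracy',
                       cv=5)
gsearch.fit(train_X, train_Y)           # cross-validate inside the training split
best_clf = gsearch.best_estimator_      # already refit on train_X/train_Y (refit=True)
pred_Y = best_clf.predict(test_X)       # predict on the untouched 20% test split
report(test_Y, pred_Y)                  # accuracy / F1 / recall / precision / ROC

Because GridSearchCV uses refit=True by default, best_estimator_ is already retrained with the winning parameters on the whole training split, so it can be used for prediction directly.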
