评估指标【交叉验证&ROC曲线】

 1 # -*- coding: utf-8 -*-
 2 """
 3 Created on Mon Sep 10 11:21:27 2018
 4 
 5 @author: zhen
 6 """
 7 from sklearn.datasets import fetch_mldata
 8 import numpy as np
 9 from sklearn.linear_model import SGDClassifier
10 from sklearn.model_selection import cross_val_score
11 from sklearn.model_selection import cross_val_predict
12 from sklearn.metrics import precision_recall_curve
13 import matplotlib
14 import matplotlib.pyplot as plt
15 from sklearn.metrics import roc_curve
16 from sklearn.metrics import roc_auc_score
17 from sklearn.ensemble import RandomForestClassifier
18 
19 mnist = fetch_mldata('MNIST original', data_home='D:/AnalyseData学习资源库/人工智能开发/分类评估/资料/test_data_home')
20 
21 x, y = mnist['data'], mnist['target']
22 some_digit = x[36000]  #获取第36000行数据
23 
24 some_digit_image = some_digit.reshape(28, 28)
25 
26 plt.imshow(some_digit_image, cmap=matplotlib.cm.binary,
27            interpolation='nearest', vmin=0, vmax=1)
28 plt.axis('off')
29 plt.show()
30 
31 x_train, x_test, y_train, y_test = x[:60000], x[60000:], y[:60000], y[60000:]
32 shuffle_index = np.random.permutation(60000)
33 
34 x_train, y_train = x_train[shuffle_index], y_train[shuffle_index]
35 
36 y_train_5 = (y_train == 5)
37 y_test_5 = (y_test == 5)
38 
39 sgd_clf = SGDClassifier(loss='log', random_state=42, max_iter=1000, tol=1e-4)
40 sgd_clf.fit(x_train, y_train_5)  
41 
42 result = sgd_clf.predict([some_digit])
43 
44 print(cross_val_score(sgd_clf, x_train, y_train_5, cv=3, scoring='accuracy'))
45 print(cross_val_score(sgd_clf, x_train, y_train_5, cv=3, scoring='precision'))
46 print(cross_val_score(sgd_clf, x_train, y_train_5, cv=3, scoring='recall'))
47 
48 sgd_clf.fit(x_train, y_train_5)
49 
50 y_scores = sgd_clf.decision_function([some_digit])
51 
52 threshold = 0
53 y_some_digit_pred = (y_scores > threshold)
54 
55 threshold = 200000
56 y_some_digit_pred = (y_scores > threshold)
57 
58 # cv 数据集划分的个数
59 y_scores = cross_val_predict(sgd_clf, x_train, y_train_5, cv=3, method='decision_function')
60 
61 precisions, recalls, thresholds = precision_recall_curve(y_train_5, y_scores)
62 
63 
def plot_precision_recall_vs_threshold(precisions, recalls, thresholds):
    """Plot precision and recall against the decision threshold, then show the figure.

    The inputs are expected in the form returned by
    ``sklearn.metrics.precision_recall_curve``, where ``precisions`` and
    ``recalls`` carry one more entry than ``thresholds``. The trailing entry
    of each is dropped so both series line up with ``thresholds``.
    """
    # (y-values, line format, legend text) for each dashed curve.
    series = (
        (precisions, 'b--', 'Precision'),
        (recalls, 'r--', 'Recall'),
    )
    for values, line_fmt, legend_text in series:
        plt.plot(thresholds, values[:-1], line_fmt, label=legend_text)

    plt.xlabel("Threshold")
    plt.legend(loc='upper left')
    plt.ylim([0, 1])
    plt.show()


def plot_roc_curve(fpr, tpr, label=None):
    """Plot one ROC curve against the random-guess diagonal and show the figure.

    Parameters
    ----------
    fpr, tpr : array-like
        False-positive and true-positive rates, e.g. the first two values
        returned by ``sklearn.metrics.roc_curve``.
    label : str, optional
        Legend text for the curve. When omitted the curve is labelled
        ``'roc'``, which is what this function always used before ``label``
        was honoured.
    """
    # Previously `label` was accepted but ignored (the legend always said 'roc').
    plt.plot(fpr, tpr, linewidth=2, label='roc' if label is None else label)
    # Dashed diagonal: the curve of a classifier that guesses at random.
    plt.plot([0, 1], [0, 1], 'k--', label='mid')
    # plt.axis([xmin, xmax, ymin, ymax]) pins both axes to [0, 1]. (The old
    # commented-out plt.axes([...]) is a different call: it adds a new Axes at
    # [left, bottom, width, height] in figure coordinates.)
    plt.axis([0, 1, 0, 1])
    plt.legend(loc='lower right')
    plt.xlabel('fpr')
    plt.ylabel('tpr')
    plt.show()


# Precision/recall trade-off of the SGD classifier as the threshold varies.
plot_precision_recall_vs_threshold(precisions, recalls, thresholds)

# ROC curve of the SGD classifier from its out-of-fold decision scores. The
# third output is named roc_thresholds so the precision-recall `thresholds`
# above is not overwritten.
fpr, tpr, roc_thresholds = roc_curve(y_train_5, y_scores)
plot_roc_curve(fpr, tpr)

print('SGD ROC AUC:', roc_auc_score(y_train_5, y_scores))

# The random forest is scored with out-of-fold predict_proba output. Its
# columns follow classes_ = [False, True] (y_train_5 is boolean), so column 1
# is the probability of "is a 5" and serves as the ROC score.
forest_clf = RandomForestClassifier(random_state=42)
y_probas_forest = cross_val_predict(forest_clf, x_train, y_train_5, cv=3,
                                    method='predict_proba')
y_scores_forest = y_probas_forest[:, 1]
fpr_forest, tpr_forest, thresholds_forest = roc_curve(y_train_5, y_scores_forest)

# Overlay both ROC curves on one set of axes. The diagonal and axis labels
# match plot_roc_curve so the two figures read the same way.
plt.plot(fpr, tpr, 'b:', label='SGD')
plt.plot(fpr_forest, tpr_forest, label='Random Forest')
plt.plot([0, 1], [0, 1], 'k--')
plt.xlabel('fpr')
plt.ylabel('tpr')
plt.legend(loc='lower right')
plt.show()

print('Random Forest ROC AUC:', roc_auc_score(y_train_5, y_scores_forest))

          

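总结:精确率(precision)与召回率(recall)之间存在此消彼长的权衡关系——提高决策阈值通常会提高精确率、降低召回率,反之亦然(二者并非严格的反比关系)。在使用相同数据集、相同交叉验证方式的情况下,随机森林的 ROC 曲线更靠近左上角、AUC 更高,因此在本任务上优于用随机梯度下降训练的 SGDClassifier!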

posted @ 2018-09-10 16:20  云山之巅  阅读(3611)  评论(0)  编辑  收藏  举报