评估指标【交叉验证&ROC曲线】
# -*- coding: utf-8 -*-
"""
Created on Mon Sep 10 11:21:27 2018

@author: zhen

Binary "is this digit a 5?" classification on the 'MNIST original' data set.

Trains an SGD classifier, reports cross-validated accuracy / precision /
recall, plots precision and recall against the decision threshold, and then
compares the ROC curve and AUC of the SGD classifier with those of a random
forest.
"""
from sklearn.datasets import fetch_mldata
import numpy as np
from sklearn.linear_model import SGDClassifier
from sklearn.model_selection import cross_val_score
from sklearn.model_selection import cross_val_predict
from sklearn.metrics import precision_recall_curve
import matplotlib
import matplotlib.pyplot as plt
from sklearn.metrics import roc_curve
from sklearn.metrics import roc_auc_score
from sklearn.ensemble import RandomForestClassifier

# NOTE(review): fetch_mldata was deprecated in scikit-learn 0.20 and removed
# in 0.22 (fetch_openml is its replacement). It is kept here because the rest
# of the script relies on the data layout this loader returns.
mnist = fetch_mldata(
    'MNIST original',
    data_home='D:/AnalyseData学习资源库/人工智能开发/分类评估/资料/test_data_home',
)

x, y = mnist['data'], mnist['target']
some_digit = x[36000]  # sample stored at row index 36000

some_digit_image = some_digit.reshape(28, 28)

# NOTE(review): vmin/vmax clamp the colour scale to [0, 1]. Presumably the raw
# pixel values run 0-255, in which case every non-zero pixel is drawn at full
# intensity -- confirm this is intended.
plt.imshow(some_digit_image, cmap=matplotlib.cm.binary,
           interpolation='nearest', vmin=0, vmax=1)
plt.axis('off')
plt.show()

# The first 60000 rows are used for training, the remaining rows for testing.
x_train, x_test, y_train, y_test = x[:60000], x[60000:], y[:60000], y[60000:]
shuffle_index = np.random.permutation(60000)

# Shuffle the training rows (features and labels with the same permutation).
x_train, y_train = x_train[shuffle_index], y_train[shuffle_index]

# Boolean targets: True where the label is 5.
y_train_5 = (y_train == 5)
y_test_5 = (y_test == 5)

# NOTE(review): newer scikit-learn releases spell this loss 'log_loss'.
sgd_clf = SGDClassifier(loss='log', random_state=42, max_iter=1000, tol=1e-4)
sgd_clf.fit(x_train, y_train_5)

result = sgd_clf.predict([some_digit])

print(cross_val_score(sgd_clf, x_train, y_train_5, cv=3, scoring='accuracy'))
print(cross_val_score(sgd_clf, x_train, y_train_5, cv=3, scoring='precision'))
print(cross_val_score(sgd_clf, x_train, y_train_5, cv=3, scoring='recall'))

# No refit is needed here: cross_val_score works on clones of the estimator,
# so sgd_clf is still the model fitted above, and with random_state=42 a
# second fit on the same data would reproduce it anyway.
y_scores = sgd_clf.decision_function([some_digit])

# Turn the raw decision score into a prediction with two different thresholds.
threshold = 0
y_some_digit_pred = (y_scores > threshold)

threshold = 200000
y_some_digit_pred = (y_scores > threshold)

# cv: number of folds the data set is split into
y_scores = cross_val_predict(sgd_clf, x_train, y_train_5, cv=3,
                             method='decision_function')

precisions, recalls, thresholds = precision_recall_curve(y_train_5, y_scores)


def plot_precision_recall_vs_threshold(precisions, recalls, thresholds):
    """Plot precision and recall as functions of the decision threshold.

    Args:
        precisions: Precision values; the last element is dropped so the
            length matches ``thresholds``.
        recalls: Recall values; the last element is dropped likewise.
        thresholds: Threshold values used for the x axis.
    """
    plt.plot(thresholds, precisions[:-1], 'b--', label='Precision')
    plt.plot(thresholds, recalls[:-1], 'r--', label='Recall')
    plt.xlabel("Threshold")
    plt.legend(loc='upper left')
    plt.ylim([0, 1])
    plt.show()


def plot_roc_curve(fpr, tpr, label=None):
    """Plot a ROC curve together with the diagonal reference line.

    Args:
        fpr: False-positive rates (x axis).
        tpr: True-positive rates (y axis).
        label: Legend label for the curve; defaults to ``'roc'`` when None.
    """
    # Honour the caller's label instead of always hard-coding 'roc'.
    curve_label = 'roc' if label is None else label
    plt.plot(fpr, tpr, linewidth=2, label=curve_label)
    plt.plot([0, 1], [0, 1], 'k--', label='mid')
    plt.legend(loc='lower right')
    # plt.axes([0, 1, 0, 1]): the first two arguments give the position of
    # the origin, the last two the lengths of the x and y axes
    plt.xlabel('fpr')
    plt.ylabel('tpr')
    plt.show()


plot_precision_recall_vs_threshold(precisions, recalls, thresholds)

# From here on `thresholds` holds the ROC thresholds, not the PR ones.
fpr, tpr, thresholds = roc_curve(y_train_5, y_scores)
plot_roc_curve(fpr, tpr)

print(roc_auc_score(y_train_5, y_scores))

forest_clf = RandomForestClassifier(random_state=42)
y_probas_forest = cross_val_predict(forest_clf, x_train, y_train_5, cv=3,
                                    method='predict_proba')
# Column 1 of predict_proba is used as the score for roc_curve.
y_scores_forest = y_probas_forest[:, 1]
fpr_forest, tpr_forest, thresholds_forest = roc_curve(y_train_5, y_scores_forest)
plt.plot(fpr, tpr, 'b:', label='SGD')
plt.plot(fpr_forest, tpr_forest, label='Random Forest')
plt.legend(loc='lower right')
plt.show()

print(roc_auc_score(y_train_5, y_scores_forest))
总结:精确率(precision)和召回率(recall)在整体上此消彼长:阈值提高时召回率下降,精确率通常上升。对比两者的 ROC 曲线和 AUC 可知,在使用相同数据集、相同验证方式的情况下,随机森林要优于随机梯度下降!