评估指标【交叉验证&ROC曲线】
# -*- coding: utf-8 -*-
"""
Created on Mon Sep 10 11:21:27 2018

@author: zhen

Evaluate binary "is this digit a 5?" classifiers on MNIST with
cross-validation, precision/recall curves, ROC curves and ROC AUC.
"""
# NOTE(review): fetch_mldata was deprecated in scikit-learn 0.20 and removed
# in 0.22, so this import only works on an old, pinned scikit-learn. The
# documented replacement is fetch_openml; it is not swapped in here because
# it changes the data source and the label dtype -- confirm before migrating.
from sklearn.datasets import fetch_mldata
import numpy as np
from sklearn.linear_model import SGDClassifier
from sklearn.model_selection import cross_val_score
from sklearn.model_selection import cross_val_predict
from sklearn.metrics import precision_recall_curve
import matplotlib
import matplotlib.pyplot as plt
from sklearn.metrics import roc_curve
from sklearn.metrics import roc_auc_score
from sklearn.ensemble import RandomForestClassifier

mnist = fetch_mldata('MNIST original', data_home='D:/AnalyseData学习资源库/人工智能开发/分类评估/资料/test_data_home')

x, y = mnist['data'], mnist['target']
some_digit = x[36000]  # sample at row index 36000, reused below as a probe

# Each sample is a flat vector; reshape it to 28x28 to display it as an image.
some_digit_image = some_digit.reshape(28, 28)

plt.imshow(some_digit_image, cmap=matplotlib.cm.binary,
           interpolation='nearest', vmin=0, vmax=1)
plt.axis('off')
plt.show()

# First 60000 rows are used for training, the remainder for testing.
x_train, x_test, y_train, y_test = x[:60000], x[60000:], y[:60000], y[60000:]

# Shuffle the training rows. The permutation is unseeded, so the
# cross-validation folds (and therefore the scores) differ between runs.
shuffle_index = np.random.permutation(60000)
x_train, y_train = x_train[shuffle_index], y_train[shuffle_index]

# Binary targets: True where the label is 5, False otherwise.
y_train_5 = (y_train == 5)
y_test_5 = (y_test == 5)

# NOTE(review): loss='log' was renamed to 'log_loss' in scikit-learn 1.1 and
# the old spelling was removed in 1.3 -- confirm against the installed version.
sgd_clf = SGDClassifier(loss='log', random_state=42, max_iter=1000, tol=1e-4)
sgd_clf.fit(x_train, y_train_5)

result = sgd_clf.predict([some_digit])

# 3-fold cross-validated accuracy, precision and recall. cross_val_score fits
# clones of the estimator, so sgd_clf itself stays fitted on the full training
# set and does not need to be refitted afterwards.
print(cross_val_score(sgd_clf, x_train, y_train_5, cv=3, scoring='accuracy'))
print(cross_val_score(sgd_clf, x_train, y_train_5, cv=3, scoring='precision'))
print(cross_val_score(sgd_clf, x_train, y_train_5, cv=3, scoring='recall'))

# Raw decision score for the probe sample, thresholded two ways to show how
# the chosen threshold changes the predicted class.
y_scores = sgd_clf.decision_function([some_digit])

threshold = 0
y_some_digit_pred = (y_scores > threshold)

threshold = 200000
y_some_digit_pred = (y_scores > threshold)

# cv is the number of folds the training set is split into. Each sample's
# score comes from a model that did not see that sample during fitting.
y_scores = cross_val_predict(sgd_clf, x_train, y_train_5, cv=3, method='decision_function')

precisions, recalls, thresholds = precision_recall_curve(y_train_5, y_scores)


def plot_precision_recall_vs_threshold(precisions, recalls, thresholds):
    """Plot precision and recall against the decision threshold and show it.

    ``precisions`` and ``recalls`` are expected to hold one more element than
    ``thresholds`` (as returned by ``precision_recall_curve``); the trailing
    element of each is dropped so the three arrays line up.
    """
    plt.plot(thresholds, precisions[:-1], 'b--', label='Precision')
    plt.plot(thresholds, recalls[:-1], 'r--', label='Recall')
    plt.xlabel("Threshold")
    plt.legend(loc='upper left')
    plt.ylim([0, 1])
    plt.show()


def plot_roc_curve(fpr, tpr, label=None):
    """Plot a ROC curve together with the diagonal reference line and show it.

    fpr, tpr: false/true positive rates, plotted on the x and y axes.
    label: legend text for the curve; defaults to 'roc' when omitted.
    """
    # Honor the caller-supplied label; previously it was silently ignored.
    curve_label = 'roc' if label is None else label
    plt.plot(fpr, tpr, linewidth=2, label=curve_label)
    plt.plot([0, 1], [0, 1], 'k--', label='mid')
    plt.legend(loc='lower right')
    plt.xlabel('fpr')
    plt.ylabel('tpr')
    plt.show()


plot_precision_recall_vs_threshold(precisions, recalls, thresholds)

# From here on, `thresholds` holds the ROC thresholds rather than the
# precision/recall ones.
fpr, tpr, thresholds = roc_curve(y_train_5, y_scores)
plot_roc_curve(fpr, tpr)

print(roc_auc_score(y_train_5, y_scores))

forest_clf = RandomForestClassifier(random_state=42)
y_probas_forest = cross_val_predict(forest_clf, x_train, y_train_5, cv=3, method='predict_proba')
# Column 1 of the probability matrix is used as the score for the "is 5" class.
y_scores_forest = y_probas_forest[:, 1]
fpr_forest, tpr_forest, thresholds_forest = roc_curve(y_train_5, y_scores_forest)

# Overlay both ROC curves to compare the two classifiers.
plt.plot(fpr, tpr, 'b:', label='SGD')
plt.plot(fpr_forest, tpr_forest, label='Random Forest')
plt.legend(loc='lower right')
plt.show()

print(roc_auc_score(y_train_5, y_scores_forest))
总结:精确率(precision)与召回率(recall)总体上此消彼长,提高其中一个通常会降低另一个;对比两条 ROC 曲线及其 AUC 可知,在使用相同数据集、相同交叉验证方式的情况下,随机森林的表现优于随机梯度下降分类器!
【推荐】国内首个AI IDE,深度理解中文开发场景,立即下载体验Trae
【推荐】编程新体验,更懂你的AI,立即体验豆包MarsCode编程助手
【推荐】抖音旗下AI助手豆包,你的智能百科全书,全免费不限次数
【推荐】轻量又高性能的 SSH 工具 IShell:AI 加持,快人一步
· 10年+ .NET Coder 心语,封装的思维:从隐藏、稳定开始理解其本质意义
· .NET Core 中如何实现缓存的预热?
· 从 HTTP 原因短语缺失研究 HTTP/2 和 HTTP/3 的设计差异
· AI与.NET技术实操系列:向量存储与相似性搜索在 .NET 中的实现
· 基于Microsoft.Extensions.AI核心库实现RAG应用
· 10年+ .NET Coder 心语 ── 封装的思维:从隐藏、稳定开始理解其本质意义
· 地球OL攻略 —— 某应届生求职总结
· 提示词工程——AI应用必不可少的技术
· Open-Sora 2.0 重磅开源!
· 周边上新:园子的第一款马克杯温暖上架