ROC曲线绘制
ROC曲线绘制调用
源码:
View Code
View Code
源码:

import numpy as np import matplotlib.pyplot as plt from itertools import cycle from sklearn.metrics import roc_curve, auc from sklearn.preprocessing import label_binarize ''' 当报错显示:predict_proba is not available when probability=False时, 则为分类器probability参数被设置为False,导致不能计算预测概率,一些分类器(如SVM)中,predict_proba默认是禁用的 分类器中启用probability=True,例如SVM(probability=True) ''' def plot_multiclass_roc(clf, x_test, y_test, n_classes): """ 用于绘制多类分类算法的 ROC 曲线,包括 micro-averaging 和 macro-averaging 参数: clf: 已训练的分类器 x_test: 测试数据 y_test: 测试标签 n_classes: 类别数量 """ # 获得分类器的预测概率 y_score = clf.predict_proba(x_test) # 将类别标签进行二进制编码,以便用于 ROC 曲线计算 y_test_bin = label_binarize(y_test, classes=np.arange(n_classes)) # 计算每个类别的 ROC 曲线 fpr = dict() tpr = dict() roc_auc = dict() for i in range(n_classes): fpr[i], tpr[i], _ = roc_curve(y_test_bin[:, i], y_score[:, i]) roc_auc[i] = auc(fpr[i], tpr[i]) # 计算 micro 和 macro 平均 AUC 值 fpr["micro"], tpr["micro"], _ = roc_curve(y_test_bin.ravel(), y_score.ravel()) roc_auc["micro"] = auc(fpr["micro"], tpr["micro"]) # 将所有假正类率汇总为一个集合 all_fpr = np.unique(np.concatenate([fpr[i] for i in range(n_classes)])) # 对所有 ROC 曲线进行插值,然后求平均 mean_tpr = np.zeros_like(all_fpr) for i in range(n_classes): mean_tpr += np.interp(all_fpr, fpr[i], tpr[i]) mean_tpr /= n_classes # 计算 macro 平均值 fpr["macro"] = all_fpr tpr["macro"] = mean_tpr roc_auc["macro"] = auc(fpr["macro"], tpr["macro"]) # 绘制 ROC 曲线 plt.figure() lw = 2 colors = cycle(['deeppink', 'aqua', 'darkorange', 'red', 'navy', 'magenta', 'green', 'cyan', 'cornflowerblue']) for i, color in zip(range(n_classes), colors): plt.plot(fpr[i], tpr[i], color=color, lw=lw, label='ROC curve of class {0} (area = {1:0.2f})' ''.format(i, roc_auc[i])) plt.plot(fpr["micro"], tpr["micro"], label='micro-average ROC curve (area = {0:0.2f})' ''.format(roc_auc["micro"]), color='deeppink', linestyle=':', linewidth=4) plt.plot(fpr["macro"], tpr["macro"], label='macro-average ROC curve (area = {0:0.2f})' ''.format(roc_auc["macro"]), color='navy', linestyle=':', linewidth=4) plt.plot([0, 1], [0, 1], 'k--', lw=lw) plt.xlim([0.0, 1.0]) plt.ylim([0.0, 1.05]) plt.xlabel('False Positive Rate') plt.ylabel('True Positive Rate') plt.title('Receiver operating characteristic for multi-class') plt.legend(loc="lower right") plt.show()
这里需要注意的是:
当报错显示:predict_proba is not available when probability=False时,调用:
则为分类器probability参数被设置为False,导致不能计算预测概率,一些分类器(如SVM)中,predict_proba默认是禁用的
分类器中启用probability=True,例如SVM(probability=True)

from sklearn.datasets import load_wine from sklearn.model_selection import train_test_split from sklearn.svm import SVC from roc import plot_multiclass_roc # 加载葡萄酒数据集 data = load_wine() X = data.data y = data.target # 分割数据集为训练集和测试集 X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.3, random_state=421) # 创建决策树分类器 clf = SVC(gamma='scale',kernel='linear',random_state=421,C=0.32,probability=True) #probability=True与此调用有关 clf.fit(X_train, y_train) plot_multiclass_roc(clf, X_test, y_test, n_classes=3)
【推荐】国内首个AI IDE,深度理解中文开发场景,立即下载体验Trae
【推荐】编程新体验,更懂你的AI,立即体验豆包MarsCode编程助手
【推荐】抖音旗下AI助手豆包,你的智能百科全书,全免费不限次数
【推荐】轻量又高性能的 SSH 工具 IShell:AI 加持,快人一步
· 全程不用写代码,我用AI程序员写了一个飞机大战
· DeepSeek 开源周回顾「GitHub 热点速览」
· 记一次.NET内存居高不下排查解决与启示
· 物流快递公司核心技术能力-地址解析分单基础技术分享
· .NET 10首个预览版发布:重大改进与新特性概览!