ROC曲线绘制

ROC曲线绘制调用
源码:
复制代码
import numpy as np
import matplotlib.pyplot as plt
from itertools import cycle
from sklearn.metrics import roc_curve, auc
from sklearn.preprocessing import label_binarize

'''
当报错显示:predict_proba is not available when  probability=False时,
则为分类器probability参数被设置为False,导致不能计算预测概率,一些分类器(如SVM)中,predict_proba默认是禁用的
分类器中启用probability=True,例如SVM(probability=True)
'''


def plot_multiclass_roc(clf, x_test, y_test, n_classes):
    """
    用于绘制多类分类算法的 ROC 曲线,包括 micro-averaging 和 macro-averaging

    参数:
    clf: 已训练的分类器
    x_test: 测试数据
    y_test: 测试标签
    n_classes: 类别数量
    """
    # 获得分类器的预测概率
    y_score = clf.predict_proba(x_test)

    # 将类别标签进行二进制编码,以便用于 ROC 曲线计算
    y_test_bin = label_binarize(y_test, classes=np.arange(n_classes))

    # 计算每个类别的 ROC 曲线
    fpr = dict()
    tpr = dict()
    roc_auc = dict()
    for i in range(n_classes):
        fpr[i], tpr[i], _ = roc_curve(y_test_bin[:, i], y_score[:, i])
        roc_auc[i] = auc(fpr[i], tpr[i])

    # 计算 micro 和 macro 平均 AUC 值
    fpr["micro"], tpr["micro"], _ = roc_curve(y_test_bin.ravel(), y_score.ravel())
    roc_auc["micro"] = auc(fpr["micro"], tpr["micro"])

    # 将所有假正类率汇总为一个集合
    all_fpr = np.unique(np.concatenate([fpr[i] for i in range(n_classes)]))

    # 对所有 ROC 曲线进行插值,然后求平均
    mean_tpr = np.zeros_like(all_fpr)
    for i in range(n_classes):
        mean_tpr += np.interp(all_fpr, fpr[i], tpr[i])
    mean_tpr /= n_classes

    # 计算 macro 平均值
    fpr["macro"] = all_fpr
    tpr["macro"] = mean_tpr
    roc_auc["macro"] = auc(fpr["macro"], tpr["macro"])

    # 绘制 ROC 曲线
    plt.figure()
    lw = 2
    colors = cycle(['deeppink', 'aqua', 'darkorange', 'red', 'navy', 'magenta', 'green', 'cyan', 'cornflowerblue'])
    for i, color in zip(range(n_classes), colors):
        plt.plot(fpr[i], tpr[i], color=color, lw=lw,
                 label='ROC curve of class {0} (area = {1:0.2f})'
                       ''.format(i, roc_auc[i]))

    plt.plot(fpr["micro"], tpr["micro"],
             label='micro-average ROC curve (area = {0:0.2f})'
                   ''.format(roc_auc["micro"]),
             color='deeppink', linestyle=':', linewidth=4)

    plt.plot(fpr["macro"], tpr["macro"],
             label='macro-average ROC curve (area = {0:0.2f})'
                   ''.format(roc_auc["macro"]),
             color='navy', linestyle=':', linewidth=4)

    plt.plot([0, 1], [0, 1], 'k--', lw=lw)
    plt.xlim([0.0, 1.0])
    plt.ylim([0.0, 1.05])
    plt.xlabel('False Positive Rate')
    plt.ylabel('True Positive Rate')
    plt.title('Receiver operating characteristic for multi-class')
    plt.legend(loc="lower right")
    plt.show()
View Code
复制代码

 

这里需要注意的是:

当报错显示:predict_proba is not available when  probability=False时,
则为分类器probability参数被设置为False,导致不能计算预测概率,一些分类器(如SVM)中,predict_proba默认是禁用的
分类器中启用probability=True,例如SVM(probability=True)
调用:
复制代码
from sklearn.datasets import load_wine
from sklearn.model_selection import train_test_split
from sklearn.svm import SVC
from roc import plot_multiclass_roc

# 加载葡萄酒数据集
data = load_wine()
X = data.data
y = data.target

# 分割数据集为训练集和测试集
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.3, random_state=421)
# 创建决策树分类器
clf = SVC(gamma='scale',kernel='linear',random_state=421,C=0.32,probability=True)
#probability=True与此调用有关
clf.fit(X_train, y_train)

plot_multiclass_roc(clf, X_test, y_test, n_classes=3)
View Code
复制代码

 

posted @   一眉师傅  阅读(13)  评论(0编辑  收藏  举报
相关博文:
阅读排行:
· 全程不用写代码,我用AI程序员写了一个飞机大战
· DeepSeek 开源周回顾「GitHub 热点速览」
· 记一次.NET内存居高不下排查解决与启示
· 物流快递公司核心技术能力-地址解析分单基础技术分享
· .NET 10首个预览版发布:重大改进与新特性概览!
点击右上角即可分享
微信分享提示