Python：根据分类结果求 ROC 曲线和 AUC

#!/usr/bin/python3
# _*_coding:utf-8 _*_

# @Time       :2021/2/21 23:14
# @Author    :jory.d
# @File       :roc_auc.py
# @Software    :PyCharm
# @Desc: 绘制多分类的ROC AUC曲线

import matplotlib as mpl

# mpl.use('Agg')  # Agg   TkAgg
import matplotlib.pyplot as plt
import numpy as np
from sklearn import metrics
from sklearn.preprocessing import label_binarize
import random
from pprint import pprint

np.set_printoptions(precision=2)


def get_other_metrics(label_names, y_trues, y_probs):
    """
    Compute per-class Precision, Recall and F1 score.

    The predicted class of each sample is the arg-max over its score vector,
    so ``y_trues`` is expected to hold integer class indices.

    :param label_names: list of class names; only its length is used, to fix
        the set and order of reported classes: 0 .. len(label_names) - 1.
    :param y_trues: list of ground-truth class indices, one per sample.
    :param y_probs: list of per-sample score vectors (one score per class),
        same length as ``y_trues``.
    :return: tuple ``(Precision, Recall, F1_Score)``; each holds one value per
        class, in ``label_names`` order.
    """
    assert type(label_names) is list
    assert type(y_trues) is list
    assert type(y_probs) is list
    assert len(y_trues) == len(y_probs)
    y_true = np.array(y_trues)
    y_prob = np.array(y_probs)
    y_pred = np.argmax(y_prob, axis=-1)

    # Pin the label set explicitly. Without ``labels``, sklearn only reports
    # classes that appear in y_true or y_pred, so the returned arrays would be
    # shorter than label_names (and mis-aligned) whenever a class is missing.
    class_ids = np.arange(len(label_names))
    Precision = metrics.precision_score(y_true, y_pred, labels=class_ids, average=None)
    Recall = metrics.recall_score(y_true, y_pred, labels=class_ids, average=None)
    F1_Score = metrics.f1_score(y_true, y_pred, labels=class_ids, average=None)
    return Precision, Recall, F1_Score


def get_cmap(N):
    '''
    Return a function that maps each index in 0, 1, ..., N-1 to a distinct
    RGBA color sampled from matplotlib's 'hsv' colormap.

    :param N: number of distinct colors needed.
    :return: callable ``index -> RGBA color``.
    '''
    import matplotlib.cm as cmx
    import matplotlib.colors as colors
    # 'hsv' is cyclic: both ends (0.0 and 1.0) are the same red. Normalising
    # over [0, N] instead of [0, N-1] maps the indices onto [0, (N-1)/N], so
    # the first and the last index no longer share a color. max(N, 1) avoids
    # an inverted range (vmax < vmin) when N == 0.
    color_norm = colors.Normalize(vmin=0, vmax=max(N, 1))
    scalar_map = cmx.ScalarMappable(norm=color_norm, cmap='hsv')

    def map_index_to_rgb_color(index):
        return scalar_map.to_rgba(index)

    return map_index_to_rgb_color


def create_roc_auc(label_names, y_trues, y_probs, png_save_path, is_show=True):
    """
    Compute one-vs-rest ROC curves and AUC per class with sklearn and plot them.

    :param label_names: list of class names; entry ``i`` names class index ``i``
        in the legend.
    :param y_trues: list of ground-truth class indices, one per sample.
    :param y_probs: list of per-sample score vectors, one score per class.
    :param png_save_path: path the figure is saved to (PNG format).
    :param is_show: if True, also display the figure with ``plt.show()``.
    :return: the matplotlib Figure.
    """
    assert type(label_names) is list
    assert type(y_trues) is list
    assert type(y_probs) is list
    assert len(y_trues) == len(y_probs)

    labels = list(label_names)
    n_classes = len(labels)
    y_true = np.array(y_trues)
    y_prob = np.array(y_probs)
    # One-vs-rest one-hot encoding: column i is 1 where the sample's class is i.
    # Built directly instead of with label_binarize, which returns a single
    # column when there are only 2 classes (breaking the [:, i] access below)
    # and takes ``classes`` keyword-only in recent scikit-learn versions.
    y_true_one_hot = (y_true.reshape(-1, 1) == np.arange(n_classes)).astype(int)

    # Compute ROC curve and ROC area for each class.
    fpr, tpr, roc_auc = {}, {}, {}
    for i in range(n_classes):
        fpr[i], tpr[i], _ = metrics.roc_curve(y_true_one_hot[:, i], y_prob[:, i])
        roc_auc[i] = metrics.auc(fpr[i], tpr[i])

    pprint(fpr)
    pprint(tpr)
    print('AUC: {}'.format(roc_auc))
    mpl.rcParams['font.sans-serif'] = u'DejaVu Sans'  # DejaVu Sans   SimHei
    mpl.rcParams['axes.unicode_minus'] = False

    fig = plt.figure()
    # Fixed palette for small class counts. White ('w') is deliberately left
    # out: a white curve would be invisible on the usual white background.
    color = ('b', 'g', 'r', 'c', 'm', 'y', 'k')
    cmap = get_cmap(n_classes)
    for i in range(n_classes):
        # FPR on the x axis, TPR on the y axis.
        _col = cmap(i) if n_classes > len(color) else color[i]
        plt.plot(fpr[i], tpr[i], c=_col, lw=2, alpha=0.7,
                 label=u'%s AUC=%.3f' % (labels[i], roc_auc[i]))

    # Diagonal reference line of a random classifier.
    plt.plot((0, 1), (0, 1), c='#808080', lw=1, ls='--', alpha=0.7)
    plt.xlim((-0.01, 1.02))
    plt.ylim((-0.01, 1.02))
    plt.xticks(np.arange(0, 1.1, 0.1))
    plt.yticks(np.arange(0, 1.1, 0.1))
    plt.xlabel('False Positive Rate', fontsize=13)
    plt.ylabel('True Positive Rate', fontsize=13)
    # Pass the on/off flag positionally: the ``b=`` keyword was renamed to
    # ``visible=`` in newer matplotlib and the old name no longer works there.
    plt.grid(True, ls=':')
    plt.legend(loc='lower right', fancybox=True, framealpha=0.8, fontsize=12)
    plt.title(u'ROC curve', fontsize=17)
    plt.savefig(png_save_path, format='png')
    if is_show:
        plt.show()

    return fig


def create_roc_self(label_names, y_trues, y_probs, png_save_path, is_show=True):
    """
    Compute per-class TPR/FPR by hand at thresholds 0.1, 0.2, ..., 1.0 and plot
    them. Intended to help choose a per-class score threshold for
    post-processing.

    For class ``i`` a sample counts as predicted positive when its class-``i``
    score is ``>= threshold`` (one-vs-rest; each class is evaluated on its own).

    :param label_names: list of class names; entry ``i`` titles the subplot of
        class index ``i``.
    :param y_trues: list of ground-truth class indices, one per sample.
    :param y_probs: list of per-sample score vectors, one score per class.
    :param png_save_path: path the figure is saved to (PNG format).
    :param is_show: if True, also display the figure with ``plt.show()``.
    :return: the matplotlib Figure (two subplot columns, one subplot per class).
    """
    assert type(label_names) is list
    assert type(y_trues) is list
    assert type(y_probs) is list
    assert len(y_trues) == len(y_probs)

    n_classes = len(label_names)
    y_trues = np.array(y_trues)
    y_probs = np.array(y_probs)
    # One-vs-rest one-hot encoding: column i is 1 where the sample's class is i.
    # Built directly instead of with label_binarize, which returns a single
    # column when there are only 2 classes and takes ``classes`` keyword-only
    # in recent scikit-learn versions.
    y_trues_one_hot = (y_trues.reshape(-1, 1) == np.arange(n_classes)).astype(int)
    print(y_trues)
    print(y_trues_one_hot)

    tpr_dict, fpr_dict = {}, {}
    thresh = [i / 10 for i in range(1, 11)]
    for i in range(n_classes):
        tpr_dict[i] = []
        fpr_dict[i] = []
        y_true = y_trues_one_hot[:, i]  # [n,]
        y_pred_prob = y_probs[:, i]
        # TPR and FPR at each threshold 0.1 ... 1.0.
        for th in thresh:
            # tpr = tp / (tp + fn), fpr = fp / (tn + fp); the 1e-5 term avoids
            # division by zero when a class has no positives / no negatives.
            y_pred2 = np.where(y_pred_prob >= th, 1, 0)
            tp = np.sum(y_pred2[y_true == 1] == 1)
            fn = np.sum(y_pred2[y_true == 1] == 0)
            fp = np.sum(y_pred2[y_true == 0] == 1)
            tn = np.sum(y_pred2[y_true == 0] == 0)
            tpr = tp / (tp + fn + 1e-5)
            fpr = fp / (tn + fp + 1e-5)
            print(f'thres={th}, tpr={tpr}, fpr={fpr}')
            tpr_dict[i].append(round(tpr, 2))
            fpr_dict[i].append(round(fpr, 2))

    pprint('tpr: {}'.format(tpr_dict))
    pprint('fpr: {}'.format(fpr_dict))

    cols = 2
    # Ceiling division. round() under-counts here (round(2.5) == 2,
    # round(0.5) == 0), which dropped the last class from the grid and left
    # ``ax`` undefined for a single class.
    rows = (n_classes + cols - 1) // cols
    fig = plt.figure(figsize=(12, 12), dpi=150)
    fig.suptitle('per class tpr and fpr', fontsize='xx-large')
    ticks = np.arange(0, 1.1, 0.2)
    ax = None
    for idx in range(n_classes):
        ax = fig.add_subplot(rows, cols, idx + 1)
        ax.plot(thresh, tpr_dict[idx], c='b', label='tpr')
        ax.plot(thresh, fpr_dict[idx], c='r', label='fpr')
        # Without a title there is no way to tell which subplot is which class.
        ax.set_title(str(label_names[idx]))
        ax.set_xlabel('thres', fontsize='x-large')
        ax.set_ylabel('tpr_fpr', fontsize='x-large')
        ax.set_xticks(ticks)
        ax.set_yticks(ticks)

    # Every subplot carries the same two curves, so one shared legend
    # (taken from the last subplot) is enough.
    if ax is not None:
        handles, labels = ax.get_legend_handles_labels()
        fig.legend(handles, labels, loc='lower right', fontsize='x-large')
    plt.savefig(png_save_path, format='png')
    if is_show:
        plt.show()

    return fig


# Demo: compute per-class fpr / tpr on random labels and scores.

if __name__ == '__main__':
    # Seed inside the main guard so that merely importing this module does
    # not reset the global NumPy RNG of the importing program.
    np.random.seed(888)
    labels = ['A', 'B', 'C']
    batch_size = 100
    # Ground truth and predicted scores.
    y_true = np.random.randint(0, len(labels), [batch_size]).tolist()
    y_prob = np.random.random([batch_size, len(labels)]).tolist()
    # _ = create_roc_auc(labels, y_true, y_prob, './ss1.png')
    _ = create_roc_self(labels, y_true, y_prob, './ss2.png')

 

posted @ 2021-03-10 22:47  dangxusheng  阅读(384)  评论(0)  编辑  收藏  举报