roc曲线含义

为了描述问题方便,只讨论二分类问题.对于分类问题的分类结果,可以得到如下"分类结果混淆矩阵".

通过"分类结果混淆矩阵"给出"真正率"(TPR)和"假正率"(FPR)的定义.
真正率:

\[\begin{equation}\label{tpr} TPR=\frac{TP}{TP+FN} \end{equation} \]

假正率

\[\begin{equation}\label{fpr} FPR=\frac{FP}{TN+FP} \end{equation} \]

对于二分类算法,算法得到的结果是给每个样本属于正例还是反例给一个概率. 下表是个例子.

编号 阳性概率 阴性概率
1 0.1 0.9
2 0.4 0.6
3 0.35 0.45
4 0.8 0.2

上表这个数据集有4个样本.这里只是为了说明问题方便,样本量较小,实际数据集中样本肯定不止4个.如果要求roc曲线,有如下步骤:

  1. 求每个样本"正例"概率."正例"是人为定义的,这里定义阳性为"正例",那么阴性就是"反例".当然也能定义阴性为"正例",那么阳性就为"反例".定义谁是"正例"不会影响roc曲线的最终结果.这里定义阳性为"正例",那么4个样本"正例"的概率是0.1,0.4,0.35,0.8 已知这4个样本真实分类情况是阴阴阳阳,那么在定义阳性为"正例"的情况下,真实的样本情况target记为0,0,1,1
  2. 将"正例"的概率按从大到小的顺序排列.得到数组thresholds
  3. 用数组thresholds的每个元素(记为threshold)与"正例"的概率作比较,若"正例"的概率大于threshold,该样本的分类结果记为1,若"正例"的概率小于threshold,该样本的分类结果记为0. 再考虑已知的真实的分类情况就可以得若干个"分类结果混淆矩阵", "分类结果混淆矩阵"的个数与样本数相等.对于示例数据集,数组thresholds的长度为4,所以该步骤能得到4个分类结果如下表所示.
threshold 分类结果
0.8 0, 0, 0, 1
0.4 0, 1, 0, 1
0.35 0, 1, 1, 1
0.1 1, 1, 1, 1

根据该表可以进一步求得4个"分类结果混淆矩阵".
4. 由"分类结果混淆矩阵",再根据公式\ref{tpr}和\ref{fpr}计算出各个"分类结果混淆矩阵"对应的TPR,FPR的值.
5. 以FPR为横坐标,以TPR为纵坐标,在直角坐标系上标出各点,再将各点用直线连接起来,就得到ROC曲线图.
1-4步可以用如下代码表示:

# coding: utf-8
import copy


def cal_roc(target, origin, threshold):
    predict = list()
    for index, o in enumerate(origin):
        if origin[index] >= threshold:
            predict.append(1)
        else:
            predict.append(0)
    print(predict)
    tp = fn = fp = tn = 0
    for index, t in enumerate(target):
        if target[index] == 1 and predict[index] == 1:
            tp += 1
        elif target[index] == 1 and predict[index] == 0:
            fn += 1
        elif target[index] == 0 and predict[index] == 1:
            fp += 1
        else:
            tn += 1
    return tp / (tp + fn), fp / (tn + fp)


def get_roc(target, scores):
    origin = copy.deepcopy(scores)
    scores.sort(reverse=True)
    fpr_list = list()
    tpr_list = list()
    thresholds = list()
    for index, s in enumerate(scores):
        thresholds.append(s)
        tpr, fpr = cal_roc(target, origin, s)
        fpr_list.append(fpr)
        tpr_list.append(tpr)
    print(fpr_list)
    print(tpr_list)
    print(thresholds)


if __name__ == '__main__':
    target = [0, 0, 1, 1]
    scores = [0.1, 0.4, 0.35, 0.8]
    get_roc(target, scores)

参考资料

周志华,机器学习,清华大学出版社,p30-34,2016.
sklearn官方文档

posted on 2022-04-06 12:54  荷楠仁  阅读(481)  评论(0编辑  收藏  举报

导航