roc曲线含义
为了描述问题方便,只讨论二分类问题.对于分类问题的分类结果,可以得到如下"分类结果混淆矩阵".
通过"分类结果混淆矩阵"给出"真正率"(TPR)和"假正率"(FPR)的定义.
真正率:
\[\begin{equation}\label{tpr}
TPR=\frac{TP}{TP+FN}
\end{equation}
\]
假正率
\[\begin{equation}\label{fpr}
FPR=\frac{FP}{TN+FP}
\end{equation}
\]
对于二分类算法,算法得到的结果是给每个样本属于正例还是反例给一个概率. 下表是个例子.
编号 | 阳性概率 | 阴性概率 |
---|---|---|
1 | 0.1 | 0.9 |
2 | 0.4 | 0.6 |
3 | 0.35 | 0.45 |
4 | 0.8 | 0.2 |
上表这个数据集有4个样本.这里只是为了说明问题方便,样本量较小,实际数据集中样本肯定不止4个.如果要求roc曲线,有如下步骤:
- 求每个样本"正例"概率."正例"是人为定义的,这里定义阳性为"正例",那么阴性就是"反例".当然也能定义阴性为"正例",那么阳性就为"反例".定义谁是"正例"不会影响roc曲线的最终结果.这里定义阳性为"正例",那么4个样本"正例"的概率是0.1,0.4,0.35,0.8 已知这4个样本真实分类情况是阴阴阳阳,那么在定义阳性为"正例"的情况下,真实的样本情况target记为0,0,1,1
- 将"正例"的概率按从大到小的顺序排列.得到数组thresholds
- 用数组thresholds的每个元素(记为threshold)与"正例"的概率作比较,若"正例"的概率大于threshold,该样本的分类结果记为1,若"正例"的概率小于threshold,该样本的分类结果记为0. 再考虑已知的真实的分类情况就可以得若干个"分类结果混淆矩阵", "分类结果混淆矩阵"的个数与样本数相等.对于示例数据集,数组thresholds的长度为4,所以该步骤能得到4个分类结果如下表所示.
threshold | 分类结果 |
---|---|
0.8 | 0, 0, 0, 1 |
0.4 | 0, 1, 0, 1 |
0.35 | 0, 1, 1, 1 |
0.1 | 1, 1, 1, 1 |
根据该表可以进一步求得4个"分类结果混淆矩阵".
4. 由"分类结果混淆矩阵",再根据公式\ref{tpr}和\ref{fpr}计算出各个"分类结果混淆矩阵"对应的TPR,FPR的值.
5. 以FPR为横坐标,以TPR为纵坐标,在直角坐标系上标出各点,再将各点用直线连接起来,就得到ROC曲线图.
1-4步可以用如下代码表示:
# coding: utf-8
import copy
def cal_roc(target, origin, threshold):
predict = list()
for index, o in enumerate(origin):
if origin[index] >= threshold:
predict.append(1)
else:
predict.append(0)
print(predict)
tp = fn = fp = tn = 0
for index, t in enumerate(target):
if target[index] == 1 and predict[index] == 1:
tp += 1
elif target[index] == 1 and predict[index] == 0:
fn += 1
elif target[index] == 0 and predict[index] == 1:
fp += 1
else:
tn += 1
return tp / (tp + fn), fp / (tn + fp)
def get_roc(target, scores):
origin = copy.deepcopy(scores)
scores.sort(reverse=True)
fpr_list = list()
tpr_list = list()
thresholds = list()
for index, s in enumerate(scores):
thresholds.append(s)
tpr, fpr = cal_roc(target, origin, s)
fpr_list.append(fpr)
tpr_list.append(tpr)
print(fpr_list)
print(tpr_list)
print(thresholds)
if __name__ == '__main__':
target = [0, 0, 1, 1]
scores = [0.1, 0.4, 0.35, 0.8]
get_roc(target, scores)
参考资料
周志华,机器学习,清华大学出版社,p30-34,2016.
sklearn官方文档