关于AUC

AUC是推荐系统常用的线下评价指标,其全称是Area Under the Curve。这里的Curve一般是指ROC(受试者操作曲线,Receiver operating characteristic),所以我们所说的AUC一般是指AUROC。
本文将沿分类阈值 -> 混淆矩阵 -> ROC -> AUC的路线梳理对AUC的理解。

分类阈值->混淆矩阵

在做二分类任务时,模型一般会对每个样本输出一个分值s(有时这个分值也表示样本是正例的概率)。
在这个分值区间里,设置一个阈值t,就可以把在阈值之上的预测为正例,阈值之下的预测为负例。
根据样本的真实标签和模型的预测标签,可以分为四种情况,统计四种情况的样本个数,就可以得到一个混淆矩阵(confusion matrix):

正样本 负样本
预测为正 TP(真正例) FP(假正例)
预测为负 FN(假负例) TN(真负例)

比如下图,正样本和负样本在分值区间上是两个正态分布,阈值t可以把每个分布切分成两块。

左右调整阈值t的大小,这四个值也会随之变化。于是,每个阈值t都会得到一个混淆矩阵。

假如我们要比较N个模型,每个模型设置K个阈值,那么就有N*K个混淆矩阵。用这么多混淆矩阵来比较这些模型的效果显然不现实。

混淆矩阵->TPR/FPR

所以接下来,根据混淆矩阵里的四个值,定义两个指标:

  • True Positive Rate (TPR,真正例率) = \(\frac{TP}{TP+FN}\)
  • False Positive Rate (FPR,假正例率) = \(\frac{FP}{FP+TN}\)

对于真正例率TPR,分子是得分>t里面正样本的数目,分母是总的正样本数目。
对于假正例率FPR,分子是得分>t里面负样本的数目,分母是总的负样本数目。

一个好的模型,肯定是TPR越大越好,FPR越小越好。

分类阈值->TPR/FPR->ROC

那么当我们调整阈值t时,因为混淆矩阵里四个值的变化,这两个指标也就会随之变化。
当t变小时,因为FN变小,TPR会变大;因为TN变小,FPR也会变大。反之,当t变大时,TPR和FPR都会变小。
那么如果我们取各种不同的阈值t,并以TPR和FPR作为坐标轴,画出各个t值对应的点,就可以得到这么一个图:

这个图有这么几个特点

  • 因为好的模型肯定是TPR大而FPR小,所以这些点都在左上的部分
  • 因为t变大时,TPR和FPR都会变小,所以这些点从右上角到左下角是单调向下
  • 当t很小时,TN和FN是0,那么TPR和FPR都是1。当t很大时,TP和FP都是0,那么TPR和FPR都是0。
    所以这些点是从右上角的(1,1)到左下角的(0,0)

连接这些点,可以得到一个上凸的曲线,就是ROC曲线。

随机模型的ROC曲线

随机的模型不区分正负样本,所以在得分>t的样本里,正负样本的比例应该与总体正负样本的比例相同。无论t是多少。
\(\frac{TP}{FP}=\frac{P}{N}=\frac{TP+FN}{TN+FP}\)。因此\(\frac{TP}{TP+FN}=\frac{FP}{TN+FP}\),即TPR\(=\)FPR。
所以随机模型的ROC曲线是一条连接(0,0)和(1,1)的直线。

我们也可以画出正负例在随机模型的分值区间上的分布,比如下图的两种情况,正负例在分值区间上同样均匀分布,或者正负例在分值区间上有相同的正态分布,TPR和FPR都会随着t的变化始终相等。

ROC->AUC

有了ROC曲线,比较多个模型的效果,就是比较多条ROC曲线。
曲线越往左上凸,模型效果越好。

但二维的曲线比较起来也还是不方便。

所以接下来定义ROC曲线下的面积为AUROC(Area Under the ROC Curve),大部分时候简写为AUC。
好的模型更向左上凸,所以曲线下面积也就更大,AUC也就更大。
所以就可以通过比较N个模型的AUC的大小来方便地比较它们的分类效果。

另外,随机模型的AUC是对角线下的面积,即0.5。所以任何合理的模型的AUC都应该大于0.5。

AUC的另一种解释

AUC除了以上定义为ROC曲线下面积,还有另一种常用的解释:随机选取一个正样本和一个负样本,模型给正样本的分值高于负样本分值的概率。

这个解释与以上用ROC曲线下面积来做的定义之间的关系,见StackExchange

AUC的计算

根据上面对AUC的解释,可以实现一种最简单的\(O(n^2)\)复杂度的计算方法(\(n\)是样本数)。
即取所有的正样本和负样本组成的二元组,共P*N个。统计这些二元组中,模型给正样本分值大于负样本分值的二元组所占的比例。

点击查看代码
import numpy as np

y    = np.array([1,   0,   0,   0,   1,    0,   1,    0,    0,   1  ])
pred = np.array([0.9, 0.4, 0.3, 0.9, 0.35, 0.6, 0.65, 0.32, 0.8, 0.7])

def auc():
    pos = pred[y==1]
    neg = pred[y==0]
    correct = 0
    for i in pos:
      for j in neg:
        if i > j:
          correct += 1
        elif i == j:
          correct += 0.5
    return correct / (len(pos) * len(neg))
print(auc())

from sklearn import metrics
fpr, tpr, thresholds = metrics.roc_curve(y, pred, pos_label=1)
print(metrics.auc(fpr, tpr))

对样本先排序可以把复杂度降低到\(O(n\log(n))\),见谈auc的计算

AUC的适用与不适用

适用

  • 仅关心相对序,不考虑分值大小时
  • 与阈值无关时
  • 要求对正负样本比例不敏感时,如对负样本下采样

不适用

  • 关心分值大小时
  • 某类正样本或负样本更重要时,如email spam classification

参考

END

posted @ 2024-03-02 21:25  maxuewei2  阅读(44)  评论(0编辑  收藏  举报
这是一段经过10次base64加密的密文:Vm0wd2VFMUhSblJXYTFwT1ZsWndUMVV3WkRSV2JHeDBZM3BHYUZKc1ZqTldiVEZIVjBaS2RHVkVRbFZXYkhCUVdWWlZlRll5U2tWVWJHUk9ZV3hhVFZkWGRHRlRNazE1Vkd0YVlWSnRhRzlVVnpGdlZWWmFjMWt6YUZOTlJGWjZWakkxVDJGc1NuTmpTRUpXWWxoU00xWkdXbUZqYkhCRlZXeHdWMkV5ZHpCV2FrbzBZekpHYzFOWVpGaGlSa3BoV1ZSS2IxSkdWbk5YYlVacVlraENSbFpYZUhkV01rVjZVV3BhVjJKVVFYaFdha1poWkVaT2MySkdTbWxXUjNoWFZtMTBWMWxXVWtkV1dHaFlZbGhTV0ZSV1pGTk5SbFowVFZoa1ZXSkdiRFJWTW5oelZqSktTRlJZYUZkV1JYQk1WV3BHVDJNeFduUmlSazVzWWxob2IxWXhXbE5TTWxGNFZXdGthbEp0YUhOVk1GVXhWMFpTV0dSSFJsTk5WMUo1VmpKek5WWXdNVVZTYTNCV1ZqTlJkMVpxUm1GU2JHUnpWV3hhVjFKV2NIbFhhMVpoVkRKTmVWTnJhR2hTYkVwVVZGUktiMWRXV25KWGJVWmFWbXN4TlZaSE5VOWhiRXBZVlcxb1ZtSkhhRlJXTVZwWFl6RldkVlJzYUZOaVNFRjNWa1phYjFReFdYaFRia3BxVW01Q1YxWnVjRUpOVmxweFVWaG9hbFpyV25oV1IzaFhWakpLVjFOc2JGZGlXRUpJVmxSR2ExZEdUbkphUmxwcFZqTm9kbFpHVWtOVE1EVlhWMjVTVGxaR1NuQlVWM1J6VGtaYVdFNVZPV2hpUlhCWldWVmFRMVl5Um5KVGJXaFhZa1p3ZWxsNlJtdGtSa3B5VGxaT2FXRXdjRmxXTVZwWFlUQXhTRkpyWkZoaVJscFVXVlJPUTFsV1duTlhhM1JUVW14c05WUldWakJXTVZwelkwaHNWMVl6YUZoWlZscGhVbXhrY21GR2FHbFNNVVYzVjFaU1MxVXhUa2RUYmtwaFVteGFjRlZzVWxkbGJHUllaRWRHYWsxRVZraFdNalZQVm0xRmVWVnVRbFZXYlZFd1ZqRmFZVkl5UmtsVWJGcE9ZVE5DU1ZkVVFtOVVNVnAwVTJ0a2FsSXlhR0ZVVlZwM1ZrWmFjMWRyZEd0V2F6VXdXbFZhVDJGV1pFaGFSRTVYWVRGd1dGbHFTa3BsVms1eVdrWm9XRkl4U2xGV2FrSnZVVEZzVjFkdVRtRlNlbXhYVlcweE5GWXhXWGxrUkVKVlRXdHdWMWt3Vm05WGF6RkhZMFJPV2xaV1ZqUmFSV1JIVW1zeFYyRkhiRk5pYTBvMVZteG9kMU14VlhoVWEyUlhZbXR3V0ZsclZURmpSbHB4VkcwNVZsSnRVbGhXVjNSM1ZERmFWVlpzYUZoaE1taE1WMVphUzFKc1RuVlNiRlpYVm10d1dWWkdWbUZXYlZaSVVtdHNZVkp0VWxSWmEyaERVMFphU0dWSGRHbE5WMUl3VlRKMGIyRkdUa2RqUmxwWFlsaG9NMVl3V2xOa1IxWkdUMWQwVTFaR1dscFhiRlpyWXpGYVIxTnNXbXBTVjJoWVdXeG9VMk5XY0ZaYVJrcHNWbXR3V2xsVldtOVhSa2w0VTI1b1YxWXpVbGhWZWtaaFl6RldjMXBHYUdoTk1VcFZWbGN3ZUZVeFpFZFhXR3hzVWpOU1ZsUlhkSGRUUm10M1lVYzVWMkpHYkRaWlZWSlBWakZKZWxScVVtRlNiSEJVVmpGa1IxSXlSa2RhUjJ4VVVsVndNbFpxUm05a01VbDVVbGhvV0ZkSGFGaFpiWGhoVmpGc2MyRkdUbXBOVjNoV1ZXMDFhMVpzV25OalJFSlZWbGRvZGxadGMzaGpiR1J5WVVaa1YyVnNXbFZYVmxKSFV6RktjMVJ1VmxOaVJuQndWakJhUzJJeFduTlZhMlJYVFZWc05GWnRlSE5aVmtweVYyeGtWMkV4U2tOVWJFVTVVRkU5UFE9PQ==