【评价指标】ROC曲线与AUC
一、前置知识
真阳性(TPR):正样本被正确分类个数与所有正样本的总数的比值
假阳性(FPR):负样本被错误分类个数与所有负样本的总数的比值
其中,TP
表示正确分类的正样本,TN
表示正确分类的负样本,FN
表示错误分类的负样本,FN
表示错误分类的负样本
二、ROC基本内容
[!tip]
在机器学习与深度学习模型中,模型输出阈值
的设置会直接影响到结果的判断
在二分类中一般采用Sigmoid
激活函数输出结果。如图所示,x轴表示最后模型的输出概率,y轴表示经过sigmoid
的变换后的输出结果,对应的TPR、FPR如表示所示,可以发现不同阈值的结果不同。
TPR | FPR | |
---|---|---|
左图 | 0.5 | 1 |
右图 | 0.75 | 0.75 |
不同阈值的设置会造成大量的参数和结果,为此,ROC曲线以FPR
为x轴,TPR
为y轴绘制二维图像,在ROC图上的特殊线:
y = 1
:表示正样本全部判断正确,并且x值越小模型的效果越好;x = 0
:表示负样本全部没有误判,并且y值越大模型的效果越好;y = x
:表示TPR = FPR,(1,1)点表示正样本全部判断正确,负样本全部错判;
三、AUC
- 基本理解
AUC(Area Under the Curve)表示的是在ROC曲线与坐标轴围成的面积,表示在FPR从0到1的过程中TPR的累积值
x = 0
:表示在当前阈值下,只有正样本的得分大于阈值;y = 1
:表示在当前阈值下,所有正样本的得分大于阈值;x = 1
:表示在当前阈值下,所有样本的得分都大于阈值;
因此,我们可以观察到,当阈值慢慢下降时,有:
① 逐渐有正样本的得分大于阈值,TPR上升,但此时还没有负样本被误判,因此FPR为0,表现为x = 0
;
② 当负样本的阈值开始大于阈值时,TRP不变,FPR增加,此时开始对AUC进行计算,因为AUC表示的是从0到1的过程中TPR的累积值
[!important]
FPR开始增加时,不同的TPR对AUC的累积值不同,比如两个模型分别在
TPR=0.5
和TPR=0.75
时负样本分数开始大于阈值,第二个模型在最终AUC的计算得分要优于第一个模型
③ 最好的情况就是所有的正样本得分大于负样本,此时当TPR为1时,FPR才开始增加,这时的TPR累计值为1.
- 更深层理解
AUC可以描述的是随机抽取一个正样本和负样本,正样本得分大于负样本得分的概率:
因为AUC在计算的过程中也间接计算了正负样本的排序,当出现错误顺序(误判负样本为正样本时),AUC的累积值减少,如图所示,模型而(蓝色线)的AUC累积值要大于模型一(红色线),这是由于在阈值降低的过程中,出现错误顺序的阈值更低。同时,FPR的增加,可以与坐标轴围成矩形区域,该面积可以表示为样本中得分比当前负样本得分高的所有正负样本对。
四、代码实现
def calculate_roc_acc(y, y_prop):
y = np.array(y)
y_prop = np.array(y_prop)
sorted_indexs = np.argsort(-y_prop) # 默认从小到大将索引排序,所以从大到小需要负号
y = y[sorted_indexs]
y_prop = y_prop[sorted_indexs]
posNum = sum(y)
negNum = len(y) - posNum
# 记录不同阈值下的TPR和FPR
TPRS = [0]
FPRS = [0]
# 阈值初始化
thresholds = y_prop[np.argsort(-y_prop)]
for threshold in thresholds:
TP = np.sum((y_prop >= threshold) & (y == 1)) # 预测为正且实际为正
FP = np.sum((y_prop >= threshold) & (y == 0)) # 预测为正但实际为负
TPRS.append(TP / posNum)
FPRS.append(FP / negNum)
AUC = 0
for i in range(1, len(FPRS)):
AUC += (FPRS[i] - FPRS[i - 1]) * (TPRS[i] + TPRS[i - 1]) / 2
return FPRS, TPRS, AUC
五、可视化代码
def plot_roc_curve(FPR, TPR, auc):
plt.rcParams["axes.linewidth"] = 1.8
plt.rcParams["axes.labelsize"] = 12
plt.rcParams["xtick.minor.visible"] = True
plt.rcParams["ytick.minor.visible"] = True
plt.rcParams["xtick.direction"] = "in"
plt.rcParams["ytick.direction"] = "in"
plt.rcParams["xtick.labelsize"] = 12
plt.rcParams["ytick.labelsize"] = 12
plt.rcParams["xtick.top"] = False
plt.rcParams["ytick.right"] = False
# plt.rcParams['font.sans-serif'] = ['SimHei']
plt.figure(figsize=(6, 6))
plt.plot(FPR, TPR, label=f"AUC = {auc:.2f}", color='blue')
plt.plot([0, 1], [0, 1], 'r--')
plt.title("ROC Curve")
plt.xlabel("False Positive Rate")
plt.ylabel("True Positive Rate")
plt.legend(loc="lower right")
plt.grid()
plt.show()
本文作者:九年义务漏网鲨鱼
本文链接:https://www.cnblogs.com/DLShark/p/18577765
版权声明:本作品采用知识共享署名-非商业性使用-禁止演绎 2.5 中国大陆许可协议进行许可。
【推荐】国内首个AI IDE,深度理解中文开发场景,立即下载体验Trae
【推荐】编程新体验,更懂你的AI,立即体验豆包MarsCode编程助手
【推荐】抖音旗下AI助手豆包,你的智能百科全书,全免费不限次数
【推荐】轻量又高性能的 SSH 工具 IShell:AI 加持,快人一步