PyTorch 中,评估模型性能的常用指标以及计算方式
在 PyTorch 中,评估模型性能的指标会根据不同的任务类型(如分类、回归、目标检测等)而有所不同,以下是常见任务对应的性能指标:
分类任务
1. 准确率(Accuracy)
- 定义:分类正确的样本数占总样本数的比例。它是最常用的分类指标之一,直观地反映了模型整体的分类能力。
- 公式:$Accuracy=\frac{TP + TN}{TP + TN+FP + FN}$,其中 $TP$(真正例)是实际为正类且被模型预测为正类的样本数,$TN$(真反例)是实际为反类且被模型预测为反类的样本数,$FP$(假正例)是实际为反类但被模型预测为正类的样本数,$FN$(假反例)是实际为正类但被模型预测为反类的样本数。
- 代码示例
import torch
from sklearn.metrics import accuracy_score
# 模拟真实标签和预测标签
y_true = torch.tensor([0, 1, 1, 0])
y_pred = torch.tensor([0, 1, 0, 0])
accuracy = accuracy_score(y_true, y_pred)
print(f"Accuracy: {accuracy}")
2. 精确率(Precision)
- 定义:预测为正类的样本中,实际为正类的比例。它衡量了模型预测为正类的准确性。
- 公式:$Precision=\frac{TP}{TP + FP}$
- 代码示例
from sklearn.metrics import precision_score
precision = precision_score(y_true, y_pred)
print(f"Precision: {precision}")
3. 召回率(Recall)
- 定义:实际为正类的样本中,被模型预测为正类的比例。它衡量了模型找到所有正类样本的能力。
- 公式:$Recall=\frac{TP}{TP + FN}$
- 代码示例
from sklearn.metrics import recall_score
recall = recall_score(y_true, y_pred)
print(f"Recall: {recall}")
4. F1 值(F1-score)
- 定义:精确率和召回率的调和平均数,用于综合衡量模型的性能,在精确率和召回率之间取得平衡。
- 公式:$F1 = 2\times\frac{Precision\times Recall}{Precision + Recall}$
- 代码示例
from sklearn.metrics import f1_score
f1 = f1_score(y_true, y_pred)
print(f"F1-score: {f1}")
5. ROC 曲线与 AUC
- 定义:ROC 曲线(Receiver Operating Characteristic curve)是以假正率(FPR)为横轴,真正率(TPR)为纵轴绘制的曲线,反映了模型在不同阈值下的分类性能。AUC(Area Under the Curve)是 ROC 曲线下的面积,取值范围在 0 到 1 之间,AUC 越接近 1,模型性能越好。
- 代码示例
from sklearn.metrics import roc_curve, auc
import matplotlib.pyplot as plt
# 模拟模型输出的概率
y_score = torch.tensor([0.1, 0.8, 0.3, 0.2])
fpr, tpr, thresholds = roc_curve(y_true, y_score)
roc_auc = auc(fpr, tpr)
plt.plot(fpr, tpr, label='ROC curve (area = %0.2f)' % roc_auc)
plt.plot([0, 1], [0, 1], 'k--')
plt.xlim([0.0, 1.0])
plt.ylim([0.0, 1.05])
plt.xlabel('False Positive Rate')
plt.ylabel('True Positive Rate')
plt.title('Receiver operating characteristic example')
plt.legend(loc="lower right")
plt.show()
回归任务
1. 均方误差(Mean Squared Error, MSE)
- 定义:预测值与真实值之间误差的平方的平均值,反映了预测值与真实值的平均偏离程度。
- 公式:$MSE=\frac{1}{n}\sum_{i = 1}{n}(y_i-\hat{y}_i)2$,其中 $y_i$ 是真实值,$\hat{y}_i$ 是预测值,$n$ 是样本数量。
- 代码示例
import torch.nn as nn
# 模拟真实值和预测值
y_true = torch.tensor([1.0, 2.0, 3.0])
y_pred = torch.tensor([1.2, 1.8, 3.1])
mse_loss = nn.MSELoss()
mse = mse_loss(y_pred, y_true)
print(f"MSE: {mse.item()}")
2. 均方根误差(Root Mean Squared Error, RMSE)
- 定义:均方误差的平方根,与原始数据具有相同的量纲,更直观地反映了预测值与真实值的平均误差大小。
- 公式:$RMSE=\sqrt{\frac{1}{n}\sum_{i = 1}{n}(y_i-\hat{y}_i)2}$
- 代码示例
import torch
rmse = torch.sqrt(mse)
print(f"RMSE: {rmse.item()}")
3. 平均绝对误差(Mean Absolute Error, MAE)
- 定义:预测值与真实值之间误差的绝对值的平均值,它对异常值的敏感性相对较低。
- 公式:$MAE=\frac{1}{n}\sum_{i = 1}^{n}|y_i-\hat{y}_i|$
- 代码示例
mae_loss = nn.L1Loss()
mae = mae_loss(y_pred, y_true)
print(f"MAE: {mae.item()}")
4. 决定系数(Coefficient of Determination, $R^2$)
- 定义:衡量回归模型拟合优度的指标,表示模型能够解释的因变量的变异比例,取值范围在 $(-\infty, 1]$ 之间,越接近 1 表示模型拟合效果越好。
- 公式:$R^2 = 1-\frac{\sum_{i = 1}{n}(y_i-\hat{y}_i)2}{\sum_{i = 1}{n}(y_i-\bar{y})2}$,其中 $\bar{y}$ 是真实值的平均值。
- 代码示例
from sklearn.metrics import r2_score
r2 = r2_score(y_true, y_pred)
print(f"R^2: {r2}")
目标检测任务
1. 平均精度均值(Mean Average Precision, mAP)
- 定义:在目标检测任务中,对于每个类别,计算其平均精度(AP),然后将所有类别的 AP 求平均得到 mAP。AP 是根据不同召回率下的精确率计算得到的,mAP 是衡量目标检测模型性能的主要指标。
- 代码示例:在实际应用中,通常使用第三方库(如
torchvision
)中的评估函数来计算 mAP。
from torchvision import ops
# 模拟真实框和预测框
boxes_true = torch.tensor([[10, 10, 50, 50], [20, 20, 60, 60]])
labels_true = torch.tensor([0, 1])
scores_pred = torch.tensor([0.9, 0.8])
boxes_pred = torch.tensor([[12, 12, 52, 52], [22, 22, 62, 62]])
labels_pred = torch.tensor([0, 1])
# 计算 mAP
ap = ops.boxes.box_iou(boxes_true, boxes_pred)
# 这里只是简单示例,实际计算 mAP 更复杂,需要结合不同的召回率和精确率
2. 交并比(Intersection over Union, IoU)
- 定义:预测框与真实框的交集面积与并集面积的比值,用于衡量预测框与真实框的重合程度。在目标检测中,通常会设定一个 IoU 阈值(如 0.5),当预测框与真实框的 IoU 大于该阈值时,认为该预测框是正确的。
- 公式:$IoU=\frac{Area\ of\ Intersection}{Area\ of\ Union}$
- 代码示例
from torchvision.ops import box_iou
iou = box_iou(boxes_true, boxes_pred)
print(f"IoU: {iou}")
【推荐】国内首个AI IDE,深度理解中文开发场景,立即下载体验Trae
【推荐】编程新体验,更懂你的AI,立即体验豆包MarsCode编程助手
【推荐】抖音旗下AI助手豆包,你的智能百科全书,全免费不限次数
【推荐】轻量又高性能的 SSH 工具 IShell:AI 加持,快人一步
· 震惊!C++程序真的从main开始吗?99%的程序员都答错了
· 【硬核科普】Trae如何「偷看」你的代码?零基础破解AI编程运行原理
· 单元测试从入门到精通
· 上周热点回顾(3.3-3.9)
· winform 绘制太阳,地球,月球 运作规律