使用Torchmetrics快速进行验证指标的计算
TorchMetrics可以为我们提供一种简单、干净、高效的方式来处理验证指标。TorchMetrics提供了许多现成的指标实现,如Accuracy, Dice, F1 Score, Recall, MAE等等,几乎最常见的指标都可以在里面找到。torchmetrics目前已经包好了80+任务评价指标。
TorchMetrics安装也非常简单,只需要PyPI安装最新版本:
pip install torchmetrics
基本流程介绍
在训练时我们都是使用微批次训练,对于TorchMetrics也是一样的,在一个批次前向传递完成后将目标值Y和预测值Y_PRED传递给torchmetrics的度量对象,度量对象会计算批次指标并保存它(在其内部被称为state)。
当所有的批次完成时(也就是训练的一个Epoch完成),我们就可以从度量对象返回最终结果(这是对所有批计算的结果)。这里的每个度量对象都是从metric类继承,它包含了4个关键方法:
- metric.forward(pred,target) - 更新度量状态并返回当前批次上计算的度量结果。如果您愿意,也可以使用metric(pred, target),没有区别。
- metric.update(pred,target) - 与forward相同,但是不会返回计算结果,相当于是只将结果存入了state。如果不需要在当前批处理上计算出的度量结果,则优先使用这个方法,因为他不计算最终结果速度会很快。
- metric.compute() - 返回在所有批次上计算的最终结果。也就是说其实forward相当于是update+compute。
- metric.reset() - 重置状态,以便为下一个验证阶段做好准备。
也就是说:在我们训练的当前批次,获得了模型的输出后可以forward或update(建议使用update)。在批次完成后,调用compute以获取最终结果。最后,在验证轮次(Epoch)或者启用新的轮次进行训练时您调用reset重置状态指标
完整文章:
https://avoid.overfit.cn/post/bdedfe4229e04da49049c4e7d56152d1