关于多任务学习MTL的实现步骤记录

拿任务来讲解:

假设我目前已经有了一个目标检测功能,检测物体a,现在我想判断这个a是不是真实场景下的,需要对整图再加个判断,即二分类。这样的需求其实就是典型的多任务学习,即检测+分类

拿到这样的任务,目前经过实际操作已知如下有效的训练方法。

  1. 首先,训练数据要保证类别均衡,包括两个方面,检测就不用说了,最好是各个类别均衡,其次是分类方面,检测类别的实拍和非实拍也要均衡
  2. 然后数据预处理,假设batch是32,那么我们需要从分类数据中取出16张,从检测数据中取出16张,然后用各自的预处理方式进行预处理,此处的预处理指的是样本增强,最后组合成batch-32,送到模型中
  3. 计算损失,此时需要在backbone上加入一个分类头,以yolov5为例,这个分类头可以加到三个输出中的任何一个当中都可以,一般情况下加到最后一个输出上
  4. 在损失计算上,根据分类和检测的图片位置,分别计算相应的分类损失和检测损失,然后相加后反向传播,至此,多任务的整体训练流程就结束

 

posted @ 2022-09-27 11:57  海_纳百川  阅读(94)  评论(0编辑  收藏  举报
本站总访问量