目标检测量化总结
前言
最近一段时间在搞模型量化(之前量化基础为0),基本上查到了90%以上的成熟量化方案,QAT的方案真的非常不成熟,基本没有开源好用的方案。赛灵思挺成熟但仅针对自己的框架,修改代价太大了。阿里的框架不成熟,至少我在看代码的时候,他还在Fix-Bug。ONNX挺成熟,但使用人数基本没有,其作为IR工具,很少有人拿他来训练。。。。量化资料虽然多,但基本都是跑一个分类模型,至于检测的量化少之又少。
目前状态
环境:
- MMDetectionV2.15,已重构
- MQBenchV0.3,修改部分代码,修复部分BUG
- 后端Torch--V1.9.1
- 后端Tensorrt--V8.2
- 后端ONNX-ONNXRuntime--V1.19
简单试验YOLOX-Nano
- FP32:mAP17.5%
- QAT(未加载PQT):直接训练无法收敛、clip-grad收敛到较大loss无法下降。mAP无
- QAT(PQT)(无Augment):mAP12.2%
- QAT(PQT)(Augment):mAP18.4%
- QAT(INT8):mAP18.4%
更新1:
YOLOX-S
- FP32:mAP40.3%
- QAT:mAP39.7%
- QAT(INT8):mAP39.7%
未完全达到fp32的精度,YOLOX-S/Tiny都量化感知训练精度相比fp32误差在0.5以内
试验的方式和阿里加速团队基本一致,从试验结果来看整体流程较为完整。
量化理论
量化的理论较为简单(前向推理加速未涉及):
- \(r\) Float32 data
- \(q\) Quant data
- \(S\) Scale
- \(Z\) ZeroPoint
基本所有的论文都是围绕以上四个公式进行的,未对具体论文进行总结,仅看代码给出的几个简单的例子:
- Scale稳定性
由于scale的变化幅度过大会对训练造成严重的震荡,所有较低频率的修改scale才能促进训练。
def scale_smooth(tensor: torch.Tensor):
log2t = torch.log2(tensor)
log2t = (torch.round(log2t) - log2t).detach() + log2t
return 2 ** log2t
- Scale和ZP的计算
这部分是优化最多且最有效的,因为一个好的初始化至关重要
- 都是比较简单,代码一看便明了
- 最大最小值
- 均值方差
- 直方图
- Clip
- ......
- Learn/Fixed
Scale/ZP是训练还是固定?
训练的情况非常花时间,因为量化节点已经得插入上百个,如果再加上训练,速度慢的可怜🥺!而且得长时间的tune,收敛缓慢。当然效果肯定比固定好🔥。
固定的情况会节约大量时间,但精度略低于训练的情况。
量化方案
QAT:训练量化
PQT:训练后量化
目前主流还是使用PQT,比如Tensorrt、NCNN、MNN、ONNX。。。。基本前向推理的框架都支持训练后量化。
少数使用QAT(仅在PQT精度较低的情况会使用),比如pytorch、ONNX、TF。。。基本只有训练框架才支持。
QAT-PyTorch
原始的训练方案 采用手动插入节点,目前已经完全废弃,这里不做介绍。
当前主推的方案 基于torch.FX模块进行,这里简单介绍一下流程
torch.FX是DAG(双向-有向图)结构,和ONNX的Graph类似,但是核心不同。
- FX.Graph是由Node节点构成,每个Node节点表示一个Operate,Node是一个双向指针。在前向计算的时候遍历每个Node,Node调用Operate,这个操作可能是函数、类成员、类等。
- ONNX.Graph也是Node构成的,但是其使用list存储,只有读取某个Node才知道users和producer,而且权重等参数由另一个数据结构initializer存储
以下是QAT-pytorch的简易流程
注意: 其实FX模块就和ONNX一样,都是一个IR部件!!如何构建一个FX模块是难点(本人仅理解大概流程,未做具体分析)
To-Torch
由于使用了MQBench进行操作,无法直接使用官方的FX-convert进行转化。
由于使用了FX组件,也无法使用原始的convert进行转化。
考虑到转换的模型比较简单,比如检测YOLOX、YOLOV5,分类Resnet、MobileNet、ShuffleNet,所以没必要专门写一套MQBench和torch.convert的转换工具。
本人直接手写了一个类进行转换,转换采用torch原始的方式(将FX模块解析出来),YOLOX精度已对齐。
- Torch的前向推理quantized op(比如QConv、QLinear...)只支持非对称量化(torch.quint8),但是Quantize和DeQuantizae是支持对称量化(torch.qint8),在实际的模型转化中只能使用非对称量化进行。
- function得用module替换
- 需要重写torch的pqt代码
注意: 采用此方案坑还是比较多的,建议还是写个mqbench到torch.fx的中间工具。
To-ONNX
已完成,待写文档
To-TRT
已完成,待写文档
To-NCNN
已完成,待写文档
-------------------------------------------
个性签名:衣带渐宽终不悔,为伊消得人憔悴!
如果觉得这篇文章对你有小小的帮助的话,记得关注再下的公众号,同时在右下角点个“推荐”哦,博主在此感谢!