量化训练及精度调优经验分享
本文提纲:
- fx 和 eager 两种量化训练方式介绍
- 量化训练的流程介绍:以 mmdet 的 yolov3 为例
- 常用的精度调优 debug 工具介绍
- 案例分析:模型精度调优经验分享
第一部分:fx 和 eager 两种量化训练方式介绍
首先介绍一下量化训练的原理。
上图为单个神经元的计算,计算形式是加权求和,再经过非线性激活后得到输出,这个输出又可以作为下一个神经元的输入继续运输,所以神经网络的基础运算是矩阵的乘法。如果神经元的计算全部采用 float32 的形式,模型的内存占用和数据搬运都会很占资源。如果用 int8 替换 float32,内存的搬运效率能提高 75%,充分展示了量化的有效性。由于两个 int8 相乘会超出 int8 的表示范围,为了防止溢出,累加器使用 int32 类型的,累加后的结果会再次 requantized 到 int8;
量化的目标就是在尽可能不影响模型精度的情况下降低模型的功耗,实现模型压缩效果,常见的量化方式有后量化训练 PTQ 和量化感知训练 QAT。
量化感知训练其实是一种伪量化的过程,即在训练过程中模拟浮点转定点的量化过程,数据虽然都是表示为 float32,但实际的值会间隔地受到量化参数的限制。具体方法是在某些 op 前插入伪量化节点(fake quantization nodes),伪量化节点有两个作用:
1.在训练时,用以统计流经该 op 的数据的最大最小值,便于在部署量化模型时对节点进行量化
2.伪量化节点参与模型训练的前向推理过程,因此会模型训练中导入了量化损失,但伪量化节点是不参与梯度更新过程的。
上图是模型学习量化损失的示意图, 正常的量化流程是 quantize->mul(int)->dequantize,而伪量化是对原先的 float 先 quantize 到 int,再 dequantize 到 float,这个步骤用于模拟量化过程中 round 操作所带来的误差,用这个误差再去进行前向运算。上图可以比较直观的表示引起误差的原因,从左到右数第 4 个黑点表示一个浮点数,quantize 后映射到 253,dequantize 后取到了第 5 个黑点,这就引起了误差。
地平线基于 PyTorch 开发的 horizon_plugin_pytorch 量化训练工具,同时支持 Eager 和 fx 两种模式。
eager 模式的使用方式建议参考用户手册 -4.2 量化感知训练章节(4.2.2。 快速上手中有完整的快速上手示例,各使用阶段注意事项建议参考 4.2.3。 使用指南)。fx 模式的相关 API 介绍请参考用户手册 -4.2.3.4.2。 主要接口参数说明章节
第二部分:量化训练的流程介绍:以 mmdet 的 yolov3 为例
QAT 流程介绍
准备好浮点模型,加载训好的浮点权重
model = build_detector(
cfg.model,
train_cfg=cfg.get('train_cfg'),
test_cfg=cfg.get('test_cfg'))
model.init_weights()# 加载config里的 init_cfg
设置 BPU 架构
set_march(March.BAYES)
算子融合(eager 模式需要,fx 可省略)
# qat: run fuse_module to fuse conv+bn/relu/add op
model.backbone.fuse_modules()
model.neck.fuse_modules()
model.bbox_head.fuse_modules()
设置量化配置
- 整个 model 使用默认的 qconfig
- 模型的输出,配置高精度输出
- det 模型 head 输出的 loss 损失函数的 qconfig 设置为 None
# qat: set qconfig for float model
model.qconfig = get_default_qat_qconfig()
# qat: set default_qat_out_qconfig for last conv
for m in model.bbox_head.convs_pred:
m.qconfig = get_default_qat_out_qconfig()
# qat: set None for loss qconfig, loss should be quantized
model.bbox_head.loss_cls.qconfig = None
model.bbox_head.loss_conf.qconfig = None
model.bbox_head.loss_xy.qconfig = None
model.bbox_head.loss_wh.qconfig = None
将浮点模型转换为 qat 模型(示例使用 eager 模式)
qat_model = prepare_qat(model)
qat_model.to(torch.device("cuda:1"))
开始 qat 训练
- 可以复用浮点的 train_detector,替换 model 即可
train_detector(
qat_model,
datasets,
cfg,
distributed=distributed,
validate=(not args.no_validate),
timestamp=timestamp,
meta=meta)
qat 模型转定点(需要 load 训练好的 qat 模型权重)
quantized_model = convert(qat_model.eval())
deploy_model 和 example_input 准备
deploy_model = DeployModel(
quantized_model.backbone, quantized_model.neck,
quantized_model.bbox_head).to(torch.device("cuda:1"))
example_input = torch.randn(size=(24, 3, 320, 320), device=torch.device("cuda:1"))
Trace 模型构建静态 graph,进行编译
- eval()使 bn、dropout 等处于正确的状态
- 编译只能在 cpu 上做
- check_model 用于检查算子是否能全部跑在 bpu 上,建议提前检查
traced_model = torch.jit.trace(deploy_model.eval(), example_input)
traced_model.to(torch.device("cpu"))
example_input.to(torch.device("cpu"))
check_model(traced_model, example_input, advice=1)
compile_model(traced_model, [example_input], opt=0, hbm="model.hbm")
如果 qat 精度不达标,如何插入 calibration?
1. 准备好浮点模型,加载训好的浮点权重
2. 设置BPU架构
3. 算子融合(eager模式需要,fx可省略)
4. 设置model的量化配置
-----------------calib_model-------------------
calib_model = prepare_qat(float_model)
calib_model.eval() # 使bn、dropout等处于正确的状态
set_fake_quantize(calib_model, FakeQuantState.CALIBRATION) # 不进行伪量化操作,仅观测算子输入输出统计量,更新scale
#校准训练(可复用浮点的train_detector,替换model即可)
train_detector(
calib_model,
datasets,
cfg,
distributed=distributed,
validate=(not args.no_validate),
timestamp=timestamp,
meta=meta)
#校准精度验证
calib_model.eval()
set_fake_quantize(calib_model, FakeQuantState.VALIDATION)
val(calib_model,val_dataloader,device)
-----------此时calib_model里的scale已经更新了-------------------------
qat_model = prepare_qat(float_model)
-----------qat_model加载calib训练好的模型权重,开始qat训练-----------------------------------------------
train_detector(
qat_model,
datasets,
cfg,
distributed=distributed,
validate=(not args.no_validate),
timestamp=timestamp,
meta=meta)
伪量化节点(fake quantize)的三种状态:
- CALIBRATION 模式:即不进行伪量化操作,仅观测算子输入输出统计量,更新 scale
- QAT 模式:观测统计量并进行伪量化操作。
- VALIDATION 模式:不会观测统计量,仅进行伪量化操作。
以下常见误操作会导致一些异常现象:
- calibration 之前模型设置为 train()的状态,且未使用
set_fake_quantize
,等于是在跑 QAT 训练; - calibration 之前模型设置为 eval()的状态,且未使用
set_fake_quantize
,会导致 scale 一直处于初始状态,全为 1,calib 不起作用。 - calibration 之前模型设置为 eval()的状态,且正确使用了
set_fake_quantize
,但是在这之后又设置了一遍 model.eval(),这将导致 fake_quant 未处于训练状态,scale 一直处于初始状态,全为 1;
对 mobilenet_v2 模型做 qat 训练的设置
量化节点设置
关键代码:
from horizon_plugin_pytorch.quantization import QuantStub
self.quant = QuantStub(scale=1/128) # 一般 pyramid 输入的 Quant 层,需要手动设置 scale=1/128def fuse_modules(self):
x = self.quant(x)
# Copyright (c) OpenMMLab. All rights reserved.
import warnings
import torch.nn as nn
from mmcv.cnn import ConvModule
from mmcv.runner import BaseModule
from torch.nn.modules.batchnorm import _BatchNorm
from horizon_plugin_pytorch.quantization import QuantStub
from ..builder import BACKBONES
from ..utils import InvertedResidual, make_divisible
import torch
@BACKBONES.register_module()
class MobileNetV2(BaseModule):
arch_settings = [[1, 16, 1, 1], [6, 24, 2, 2], [6, 32, 3, 2],
[6, 64, 4, 2], [6, 96, 3, 1], [6, 160, 3, 2],
[6, 320, 1, 1]]
def __init__(self,
widen_factor=1.,
out_indices=(1, 2, 4, 7),
frozen_stages=-1,
conv_cfg=None,
norm_cfg=dict(type='BN'),
act_cfg=dict(type='ReLU6'),
norm_eval=False,
with_cp=False,
pretrained=None,
init_cfg=None):
super(MobileNetV2, self).__init__(init_cfg)
# qat: model start with Quantization node
# and set scale=1/128
self.quant = QuantStub(scale=1/128) # 一般pyramid输入的Quant层,需要手动设置scale=1/128
self.pretrained = pretrained
assert not (init_cfg and pretrained), \
'init_cfg and pretrained cannot be specified at the same time'
if isinstance(pretrained, str):
warnings.warn('DeprecationWarning: pretrained is deprecated, '
'please use "init_cfg" instead')
self.init_cfg = dict(type='Pretrained', checkpoint=pretrained)
elif pretrained is None:
if init_cfg is None:
self.init_cfg = [
dict(type='Kaiming', layer='Conv2d'),
dict(
type='Constant',
val=1,
layer=['_BatchNorm', 'GroupNorm'])
]
else:
raise TypeError('pretrained must be a str or None')
self.widen_factor = widen_factor
self.out_indices = out_indices
if not set(out_indices).issubset(set(range(0, 8))):
raise ValueError('out_indices must be a subset of range'
f'(0, 8). But received {out_indices}')
if frozen_stages not in range(-1, 8):
raise ValueError('frozen_stages must be in range(-1, 8). '
f'But received {frozen_stages}')
self.out_indices = out_indices
self.frozen_stages = frozen_stages
self.conv_cfg = conv_cfg
self.norm_cfg = norm_cfg
self.act_cfg = act_cfg
self.norm_eval = norm_eval
self.with_cp = with_cp
self.in_channels = make_divisible(32 * widen_factor, 8)
self.conv1 = ConvModule(
in_channels=3,
out_channels=self.in_channels,
kernel_size=3,
stride=2,
padding=1,
conv_cfg=self.conv_cfg,
norm_cfg=self.norm_cfg,
act_cfg=self.act_cfg)
self.layers = []
for i, layer_cfg in enumerate(self.arch_settings):
expand_ratio, channel, num_blocks, stride = layer_cfg
out_channels = make_divisible(channel * widen_factor, 8)
inverted_res_layer = self.make_layer(
out_channels=out_channels,
num_blocks=num_blocks,
stride=stride,
expand_ratio=expand_ratio)
layer_name = f'layer{i + 1}'
self.add_module(layer_name, inverted_res_layer)
self.layers.append(layer_name)
if widen_factor > 1.0:
self.out_channel = int(1280 * widen_factor)
else:
self.out_channel = 1280
layer = ConvModule(
in_channels=self.in_channels,
out_channels=self.out_channel,
kernel_size=1,
stride=1,
padding=0,
conv_cfg=self.conv_cfg,
norm_cfg=self.norm_cfg,
act_cfg=self.act_cfg)
self.add_module('conv2', layer)
self.layers.append('conv2')
def make_layer(self, out_channels, num_blocks, stride, expand_ratio):
"""Stack InvertedResidual blocks to build a layer for MobileNetV2.
Args:
out_channels (int): out_channels of block.
num_blocks (int): number of blocks.
stride (int): stride of the first block. Default: 1
expand_ratio (int): Expand the number of channels of the
hidden layer in InvertedResidual by this ratio. Default: 6.
"""
layers = []
for i in range(num_blocks):
if i >= 1:
stride = 1
layers.append(
InvertedResidual(
self.in_channels,
out_channels,
mid_channels=int(round(self.in_channels * expand_ratio)),
stride=stride,
with_expand_conv=expand_ratio != 1,
conv_cfg=self.conv_cfg,
norm_cfg=self.norm_cfg,
act_cfg=self.act_cfg,
with_cp=self.with_cp))
self.in_channels = out_channels
return nn.Sequential(*layers)
def _freeze_stages(self):
if self.frozen_stages >= 0:
for param in self.conv1.parameters():
param.requires_grad = False
for i in range(1, self.frozen_stages + 1):
layer = getattr(self, f'layer{i}')
layer.eval()
for param in layer.parameters():
param.requires_grad = False
# qat: do fuse model
def fuse_modules(self):
self.conv1.fuse_modules()
for layer_name in self.layers:
layer = getattr(self, layer_name)
if hasattr(layer, "fuse_modules"):
layer.fuse_modules()
elif isinstance(layer, nn.Sequential):
for m in layer:
if hasattr(m, "fuse_modules"):
m.fuse_modules()
def forward(self, x):
"""Forward function."""
# qat: qat model start with QuantStub
x = self.quant(x)
x = self.conv1(x)
outs = []
for i, layer_name in enumerate(self.layers):
layer = getattr(self, layer_name)
x = layer(x)
if i in self.out_indices:
outs.append(x)
return tuple(outs)
def train(self, mode=True):
"""Convert the model into training mode while keep normalization layer
frozen."""
super(MobileNetV2, self).train(mode)
self._freeze_stages()
if mode and self.norm_eval:
for m in self.modules():
# trick: eval have effect on BatchNorm only
if isinstance(m, _BatchNorm):
m.eval()
算子融合
[7.5.5. 算子融合 — Horizon Open Explorer](https://developer.horizon.ai/api/v1/fileData/horizon_j5_open_explorer_cn_doc/plugin/source/advanced_content/op_fusion.html?highlight=算子融合 算子 融合#)
举个例子:mmcv/cnn/bricks/conv_module.py
class ConvModule(nn.Module):
...
# qat: fuse conv + bn/relu
def fuse_modules(self):
fuse_list = None
if self.with_norm:
if self.with_activation:
fuse_list = ["conv", self.norm_name, "activate"] # conv+bn+relu
else:
fuse_list = ["conv", self.norm_name] # conv+bn
else:
if self.with_activation:
fuse_list = ["conv", "activate"] # conv+relu
if fuse_list is not None:
torch.quantization.fuse_modules(
self,
fuse_list,
inplace=True,
fuser_func=quantization.fuse_known_modules,
)
eager 方案麻烦的是,基本每个模块都要手动去设置算子融合
反量化节点设置
mmdetection-master/mmdet/models/dense_heads/yolo_head.py
关键代码:
self.dequant = nn.ModuleList() # 不止1个反量化节点,用list包起来
self.dequant.append(DeQuantStub())
def fuse_modules(self):
pred_map = self.dequant[i](self.convs_pred[i](x))
class YOLOV3Head(BaseDenseHead, BBoxTestMixin):
def __init__(self,
num_classes,
in_channels,
out_channels=(1024, 512, 256),
anchor_generator=dict(
type='YOLOAnchorGenerator',
base_sizes=[[(116, 90), (156, 198), (373, 326)],
[(30, 61), (62, 45), (59, 119)],
[(10, 13), (16, 30), (33, 23)]],
strides=[32, 16, 8]),
bbox_coder=dict(type='YOLOBBoxCoder'),
featmap_strides=[32, 16, 8],
one_hot_smoother=0.,
conv_cfg=None,
norm_cfg=dict(type='BN', requires_grad=True),
# qat
# act_cfg=dict(type='LeakyReLU', negative_slope=0.1),
act_cfg=dict(type='ReLU'),
loss_cls=dict(
type='CrossEntropyLoss',
use_sigmoid=True,
loss_weight=1.0),
loss_conf=dict(
type='CrossEntropyLoss',
use_sigmoid=True,
loss_weight=1.0),
loss_xy=dict(
type='CrossEntropyLoss',
use_sigmoid=True,
loss_weight=1.0),
loss_wh=dict(type='MSELoss', loss_weight=1.0),
train_cfg=None,
test_cfg=None,
init_cfg=dict(
type='Normal', std=0.01,
override=dict(name='convs_pred'))):
super(YOLOV3Head, self).__init__(init_cfg)
# Check params
assert (len(in_channels) == len(out_channels) == len(featmap_strides))
self.num_classes = num_classes
self.in_channels = in_channels
self.out_channels = out_channels
self.featmap_strides = featmap_strides
self.train_cfg = train_cfg
self.test_cfg = test_cfg
if self.train_cfg:
self.assigner = build_assigner(self.train_cfg.assigner)
if hasattr(self.train_cfg, 'sampler'):
sampler_cfg = self.train_cfg.sampler
else:
sampler_cfg = dict(type='PseudoSampler')
self.sampler = build_sampler(sampler_cfg, context=self)
self.fp16_enabled = False
self.one_hot_smoother = one_hot_smoother
self.conv_cfg = conv_cfg
self.norm_cfg = norm_cfg
self.act_cfg = act_cfg
self.bbox_coder = build_bbox_coder(bbox_coder)
self.prior_generator = build_prior_generator(anchor_generator)
self.loss_cls = build_loss(loss_cls)
self.loss_conf = build_loss(loss_conf)
self.loss_xy = build_loss(loss_xy)
self.loss_wh = build_loss(loss_wh)
self.num_base_priors = self.prior_generator.num_base_priors[0]
assert len(
self.prior_generator.num_base_priors) == len(featmap_strides)
self._init_layers()
def _init_layers(self):
self.convs_bridge = nn.ModuleList()
self.convs_pred = nn.ModuleList()
self.dequant = nn.ModuleList() # 不止1个反量化节点,用list包起来
for i in range(self.num_levels):
conv_bridge = ConvModule(
self.in_channels[i],
self.out_channels[i],
3,
padding=1,
conv_cfg=self.conv_cfg,
norm_cfg=self.norm_cfg,
act_cfg=self.act_cfg)
conv_pred = nn.Conv2d(self.out_channels[i],
self.num_base_priors * self.num_attrib, 1)
self.convs_bridge.append(conv_bridge)
self.convs_pred.append(conv_pred)
self.dequant.append(DeQuantStub())
def fuse_modules(self):
for m in self.convs_bridge:
m.fuse_modules()
def forward(self, feats):
"""Forward features from the upstream network.
Args:
feats (tuple[Tensor]): Features from the upstream network, each is
a 4D-tensor.
Returns:
tuple[Tensor]: A tuple of multi-level predication map, each is a
4D-tensor of shape (batch_size, 5+num_classes, height, width).
"""
assert len(feats) == self.num_levels
pred_maps = []
for i in range(self.num_levels):
x = feats[i]
x = self.convs_bridge[i](x)
pred_map = self.dequant[i](self.convs_pred[i](x))
pred_maps.append(pred_map)
return tuple(pred_maps),
第三部分:常用的精度调优 debug 工具介绍
工具:集成接口、量化配置检查、模型可视化、相似度对比、统计量、分步量化、异构模型部署 device 检查
第四部分:模型精度调优分享
模型精度调优时常遇到的问题:
-
calib 模型的精度和 float 对齐,quantized 模型的精度损失较大
正常情况下,calib/qat 模型的精度和 quantized 模型的精度损失很小(1%), 如果偏差过大,可能是 calib/qat 的流程不对。
原因:calib 模型伪量化节点的状态不正确,导致 calib 阶段,测试的是 float 模型的精度,而 quantized 阶段,测试的是 calib 模型的精度,所以精度损失本质上还是量化精度的损失。
如何避免:
- 正确设置 calib 训练和评测时的伪量化节点状态。
- 让客户在 calib 的基础上,做 qat, 评测 qat 模型的精度。(客户的数据量大,qat 时间太长,一直没有选择 qat,导致这个问题被暴露出来了)
如何设置正确的 calib 伪量化节点的状态?(fx 和 eager 都是一样的)
http://model.aidi.hobot.cc/api/docs/horizon_plugin_pytorch/latest/html/user_guide/calibration.html
#加载浮点模型权重
model.load_state_dict(torch.load("output/toy_experiments/model_moderate_best_soft_float_131892.pth"))
set_march(March.BAYES)
#校准配置
calib_model = prepare_qat_fx(
model,
{"":default_calib_8bit_fake_quant_qconfig,
"module_name":
...
}).to(device)
calib_model.to(device)
#校准需要全程开启eval()状态
calib_model.eval()
#校准的训练阶段,设置伪量化节点模式为 CALIBRATION
set_fake_quantize(calib_model, FakeQuantState.CALIBRATION)
train(cfg, calib_model, device, distributed)
#校准的评测阶段,设置伪量化节点的模式为 VALIDATION
set_fake_quantize(calib_model, FakeQuantState.VALIDATION)
#加载校准的模型权重
calib_model.load_state_dict(torch.load("output/toy_experiments/model_moderate_best_soft_calib_118633.pth"))
#测试校准的精度
run_test(cfg, calib_model, vis=args.vis, eval_score_iou=args.eval_score_iou, eval_all_depths=args.eval_all_depths) # 11.8650
注意:16 行的 train 在评测时,也要设置 FakeQuantState.VALIDATION,不然 scale 不生效,评测的指标也不对
常见问题:
- 数据校准之前模型设置为 train()的状态,且未使用
set_fake_quantize
,等于 caib 阶段是在跑 QAT 训练; - 校准的评测阶段,未设置伪量化节点的模式为 VALIDATION, 实际评测的是 float 模型;
总结 2: 如果做 calib,一定要仔细检查伪量化节点状态和模型状态是否正确,避免不符合预期的结果
2.当量化精度损失超过大,如何调优?
- 使用 model_profiler() 这个集成接口,生成压缩包。
- 检查是否配置高精度输出、是否存在未融合的算子、是否共享 op、是否算子分布过大 int8 兜不住?
- 注意:使用 debug 集成接口时,要保证浮点模型训练到位,并传入真实数据
3.多任务模型的精度调优建议
- qat 调优策略和常规模型一样,ptq+qat
- 如果只有一个 head 精度有损失,可以固定其他部分,单独使用这个 head 的数据做 calib
4.calib 和 qat 流程的正确衔接
calib:
#加载浮点模型权重
model.load_state_dict(torch.load("output/toy_experiments/model_moderate_best_soft_float_131892.pth"))
set_march(March.BAYES)
#校准配置
calib_model = prepare_qat_fx(
model,
{"":default_calib_8bit_fake_quant_qconfig,
"module_name":
...
}).to(device)
calib_model.to(device)
#校准需要全程开启eval()状态
calib_model.eval()
#校准的训练阶段,设置伪量化节点模式为 CALIBRATION
set_fake_quantize(calib_model, FakeQuantState.CALIBRATION)
train(cfg, calib_model, device, distributed)
#校准的评测阶段,设置伪量化节点的模式为 VALIDATION
set_fake_quantize(calib_model, FakeQuantState.VALIDATION)
#加载校准的模型权重
calib_model.load_state_dict(torch.load("output/toy_experiments/model_moderate_best_soft_calib_118633.pth"))
#测试校准的精度
run_test(cfg, calib_model, vis=args.vis, eval_score_iou=args.eval_score_iou, eval_all_depths=args.eval_all_depths) # 11.8650
qat:
set_march(March.BAYES)
qat_model = prepare_qat_fx(
model,
{"":default_qat_8bit_fake_quant_qconfig,
"module_name":
'''
}).to(device)
qat_model.to(device)
#加载校准模型权重
qat_model.load_state_dict(torch.load("output/toy_experiments/model_moderate_best_soft_calib_118633.pth"))
#训练阶段,保证模型处于model.train()状态,这样伪量化节点也处于qat模式
train(cfg, qat_model, device, distributed)
5.检查 conv 高精度输出
方式 1:查看 qconfig_info.txt,重点关注 DeQuantStub 附近的 conv 是不是 float32 输出
qconfig_info.txt
方式 2:打印 qat_model 的最后一层,查看该层是否有 (activation_post_process): FakeQuantize
高精度的 conv:
(1): ConvModule2d(
(0): Conv2d(
64, 3, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)
(weight_fake_quant): FakeQuantize(
fake_quant_enabled=tensor([1], dtype=torch.uint8), observer_enabled=tensor([1], dtype=torch.uint8), quant_min=-128, quant_max=127, dtype=qint8, qscheme=torch.per_channel_symmetric, ch_axis=0, scale=tensor([1., 1., 1.]), zero_point=tensor([0, 0, 0])
(activation_post_process): MovingAveragePerChannelMinMaxObserver(min_val=tensor([]), max_val=tensor([]))
)
)
)
)
int8 的 conv
(0): ConvModule2d(
(0): ConvReLU2d(
64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)
(weight_fake_quant): FakeQuantize(
fake_quant_enabled=tensor([1], dtype=torch.uint8), observer_enabled=tensor([1], dtype=torch.uint8), quant_min=-128, quant_max=127, dtype=qint8, qscheme=torch.per_channel_symmetric, ch_axis=0, scale=tensor([1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
1., 1., 1., 1., 1., 1., 1., 1., 1., 1.]), zero_point=tensor([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0])
(activation_post_process): MovingAveragePerChannelMinMaxObserver(min_val=tensor([]), max_val=tensor([]))
)
(activation_post_process): FakeQuantize(
fake_quant_enabled=tensor([1], dtype=torch.uint8), observer_enabled=tensor([1], dtype=torch.uint8), quant_min=-128, quant_max=127, dtype=qint8, qscheme=torch.per_tensor_symmetric, ch_axis=-1, scale=tensor([1.]), zero_point=tensor([0])
(activation_post_process): MovingAverageMinMaxObserver(min_val=tensor([]), max_val=tensor([]))
)
)
6.检查共享 op
打开 qconfig_info.txt,后面标有(n)的就是共享的
特殊情况:layernorm 在 QAT 阶段是多个小量化算子拼接而成,module 的重复调用,也会产生大量 op 共享的问题
解决办法: 将 layernorm 替换为 batchnorm,测试了 float 精度,没有下降。
7.检查未融合的算子
打开 qconfig_info.txt,全局搜 BatchNorm2d 和 ReLU,如果前面有 conv,那就是没做算子融合
可以融合的算子:
- conv+bn
- conv+relu
- conv+add
- conv+bn+relu
- conv+bn+add
- conv+bn+relu+add
8.检查数据分布特别大的算子
打开 float 模型的统计量分布,一般是 model0_statistic.txt
有两个表,第一个表是按模型结构排列的;第二个表是按数据分布范围排列的
拖到第二个表,看前几行是那些 op
可以看到很多 conv 的分布很异常,使用的是 int8 量化
解决办法:
- 检查这些 conv 后面是否有 bn,添加 bn 后,数据能收敛一些
- 如果结构上已经加了 bn,数据分布还大,可以配置 int16 量化
- int16 调这两个接口,default_qat_16bit_fake_quant_qconfig 和 default_calib_16bit_fake_quant_qconfig
- 中间算子的写法和高精度输出类似 model.xx.qconfig = default_qat_16bit_fake_quant_qconfig ()