TVM量化代码解析
TVM量化代码解析
TVM量化,非常方便,即插即用。使用加入了伪量化后的pass,替代原来的pass,一个官方提供的示例:
def test_mul_rewrite():
"""a test case where rhs of mul is not constant"""
data=relay.var("data",shape=(1,16,64,64))
multiplier=relay.sigmoid(relay.var("data",shape=(1,16,1,1)))
conv=relay.nn.conv2d(data,relay.var("weight"),
kernel_size=(3,3),
padding=(1,1),
channels=16)
act=relay.nn.relu(data=conv)
quantize_and_build(act * multiplier)
pool=relay.nn.global_avg_pool2d(data=act)
quantize_and_build(act * pool)
入口就是函数:
def quantize_and_build(out):
f=relay.Function(relay.analysis.free_vars(out),out)
mod,params=testing.create_workload(f)
with relay.quantize.qconfig(skip_conv_layers=[]):
qmod=relay.quantize.quantize(mod,params)
relay.build(qmod,"llvm",params=params)
return qmod
调用relay.quantize.quantize函数,这个函数实在太长了,只放上主体部分。
1. mod=prerequisite_optimize(mod,params)
2. calibrate_pass=tvm.transform.module_pass(
calibrate(dataset),opt_level=1,
name="QuantizeCalibrate")
quant_passes=[partition(),
annotate(),
calibrate_pass]
if not current_qconfig().do_simulation:
quant_passes.append(realize())
quant_passes.append(_transform.FoldConstant())
quantize_seq=tvm.transform.Sequential(quant_passes)
with tvm.transform.PassContext(opt_level=3,
required_pass=["QuantizeAnnotate",
"QuantizeCalibrate",
"QuantizeRealize"]):
3. with quantize_context():
mod=quantize_seq(mod)
4. q_cfg=current_qconfig()
assert q_cfg.partition_conversions in ['disabled','enabled','fully_integral']
if q_cfg.partition_conversions != 'disabled':
quantized_dtypes={q_cfg.dtype_input,q_cfg.dtype_weight,q_cfg.dtype_activation}
ensure_fully_integral=q_cfg.partition_conversions == 'fully_integral'
return partition_conversions(mod,quantized_dtypes,ensure_fully_integral)
从代码中,可看到,TVM量化需要做的就是
l 标号1,图优化部分,具体做哪些图优化,就可自己选了,如算子融合,常量折叠。
l 标号2,整个量化的步骤,包括定义quant_passes,如果发现config设置,不需要伪量化,就是inference阶段了,就把realize加进去,否则,只需要annotate及calibrate,优化量化参数。
l 标号3,开始做量化了,将一个fp32的inference graph,转成int类型的inference graph,可参照第一张图。
l 标号4,把realize的graph,或者说对于一个op的前向推理的步骤,分成前中后三部分:
比如,conv2d,input_quantization -> input_quantization*weight_quantization(core function) -> ouput_dequantization,
每一个算子计算完后,都要dequant回去,很有可能某些算子没量化,还得用fp32。
最优解肯定是全部都量化掉,直接int32跑到底,TVM搞了个参数ensure_fully_integral,保证所有的算子都量化了。
参考链接:
https://blog.csdn.net/Artyze/article/details/108776522
https://www.freesion.com/article/3155559638/
https://discuss.tvm.apache.org/t/rfc-search-based-automated-quantization/5483
【推荐】国内首个AI IDE,深度理解中文开发场景,立即下载体验Trae
【推荐】编程新体验,更懂你的AI,立即体验豆包MarsCode编程助手
【推荐】抖音旗下AI助手豆包,你的智能百科全书,全免费不限次数
【推荐】轻量又高性能的 SSH 工具 IShell:AI 加持,快人一步
· 记一次.NET内存居高不下排查解决与启示
· 探究高空视频全景AR技术的实现原理
· 理解Rust引用及其生命周期标识(上)
· 浏览器原生「磁吸」效果!Anchor Positioning 锚点定位神器解析
· 没有源码,如何修改代码逻辑?
· 全程不用写代码,我用AI程序员写了一个飞机大战
· DeepSeek 开源周回顾「GitHub 热点速览」
· 记一次.NET内存居高不下排查解决与启示
· MongoDB 8.0这个新功能碉堡了,比商业数据库还牛
· .NET10 - 预览版1新功能体验(一)