野路子码农系列(9)利用ONNX加速Pytorch模型推断
最近在做一个文本多分类的模型,非常常规的BERT+finetune的套路,考虑到运行成本,打算GPU训练后用CPU做推断。
在小破本上试了试,发现推断速度异常感人,尤其是序列长度增加之后,一条4-5秒不是梦。
于是只能寻找加速手段,早先听过很多人提到过ONNX,但从来没试过,于是就学习了一下,发现效果还挺不错的,手法其实也很简单,就是有几个小坑。
第1步 - 保存模型
首先得从torch中将模型导出成ONNX格式,可以在cross-validation的eval阶段进行这一步骤:
def eval_fn(data_loader, model, device): '此处省略其他代码' onnx_path = 'inference_model.onnx' # 指定保存路径 torch.onnx._export( model, # BERT fintune model (instance) (ids, mask, token_type_ids), # model的输入参数,装入tuple onnx_path, # 保存路径 opset_version=10, # 此处有坑,必须指定≥10,否则会报错 do_constant_folding=True, input_names=['ids', 'mask', 'token_type_ids'], # model输入参数的名称 output_names=['output'], export_params=True, dynamic_axes={ 'ids': {0: 'batch_size', 1: 'seq_length'}, # 0, 1分别代表axis 0和axis 1 'mask': {0: 'batch_size', 1: 'seq_length'}, 'token_type_ids': {0: 'batch_size', 1: 'seq_length'}, 'output': {0: 'batch_size', 1: 'seq_length'} } # 用于变长序列(比如dynamic padding)和可能改变batch size的情况 ) return '此处省略返回值'
这里需要注意的几个点:
- torch自带了导出ONNX的方法,直接用就行
- 你的模型可以有1个输入参数,也可以有多个,如果有多个,得装在tuple里
- 相应的input_names要与你的参数一一对应,放在list里
- opset_version建议设成10,默认不设的话可能会报错(ONNX export of Slice with dynamic inputs)
- 如果你在data loader里设置了collate func来进行dynamic padding的话(不同batch的文本长度可能不一样),一定要设置dynamic_axes,否则之后加载推断时会出错(因为它会要求你推断时输入的各个维度与你保存ONNX模型时的输入纬度完全一致)。
第2步 - 加载模型与推断
接下来是推断环节,首先别忘了用 pip install onnx 和 pip install onnxruntime 来安装必需的库,之后通过以下代码导入使用:
import onnxruntime as ort
接下来你可以照常写你的dataset和data loader,但需要注意的是,data loader返回的得是numpy.array,而不是torch.tensor(collate_fn里改改就行),否则报错伺候。
然后就是导入模型:
import onnxruntime as ort onnx_model_path = 'inference_model.onnx' session = ort.InferenceSession(onnx_model_path)
再把data loader的输出分别接入对应的三个参数就好了:
session.run(ids, mask, token_type_ids)
用%%timeit看一下运行时间(CPU):
4条长度为10的文本
torch:4.77s
torch+ONNX:39.7ms
4条长度为50的文本
torch:21.2s
torch+ONNX:246ms
差不多快了百倍有余,效果相当不错啦。