3、使用 TVMC Python 入门:TVM 的高级 API
本节将介绍针对 TVM 初学者设计的脚本工具。
开始前如果没有下载示例模型,需要先通过终端下载 resnet 模型:
cd myscripts/ mv resnet50-v2-7-frozen.onnx my_model.onnx
在myscripts目录下新建tvmcpythonintro.py
0、导入
from tvm.driver import tvmc
1、加载模型
将模型导入 TVMC。这一步将机器学习模型从支持的框架,转换为 TVM 的高级图形表示语言 —— Relay。这是为 TVM 中的所有模型统一起点。目前支持的框架:Keras、ONNX、TensorFlow、TFLite 和 PyTorch。
model = tvmc.load('my_model.onnx') # 第 1 步:加载
查看 Relay,可运行 model.summary()
。
所有框架都支持用 shape_dict 参数覆盖输入 shape。对于大多数框架,这是可选的;但对 PyTorch 是必需的,因为 TVM 无法自动搜索它。
推荐通过 netron 查看模型 input/shape_dict。打开模型后,单击第一个节点查看输入部分中的 name 和 shape。
2、编译
模型现在是用 Relay 表示的,下一步是将其编译到要运行的硬件(称为 target)。这个编译过程将模型从 Relay,翻译成目标机器可理解的底层语言。
编译模型需要一个 tvm.target 字符串。查看 文档 了解有关 tvm.targets 及其选项的更多信息。一些例子如下:
- cuda (英伟达 GPU)
- llvm (CPU)
- llvm -mcpu=cascadelake(英特尔 CPU)
(上面我理解的是编译到TVM运行时)
package = tvmc.compile(model, target="llvm") # 第 2 步:编译
编译完成后返回一个 package。
3、运行
编译后的 package 可在目标硬件上运行。设备输入选项有:CPU、Cuda、CL、Metal 和 Vulkan。
result = tvmc.run(package, device="cpu") # 第 3 步:运行
用 print(result)
打印结果。
第 1.5 步:调优【可选并推荐】
通过调优可进一步提高运行速度。此可选步骤用机器学习来查看模型(函数)中的每个操作,并找到一种更快的方法来运行它。这一步通过 cost 模型,以及对可能的 schedule 进行基准化来实现。
这里的 target 与编译过程用到的 target 是相同的。
tvmc.tune(model, target="llvm") # 第 1.5 步:可选 Tune
终端输出如下所示:
[Task 1/25] Current/Best: 0.00/ 0.00 GFLOPS | Progress: (0/400) | 0.00 s [Task 1/25] Current/Best: 2.72/ 17.48 GFLOPS | Progress: (2/400) | 4.41 s [Task 1/25] Current/Best: 8.50/ 17.48 GFLOPS | Progress: (4/400) | 6.54 s [Task 1/25] Current/Best: 14.82/ 17.48 GFLOPS | Progress: (6/400) | 8.78 s [Task 1/25] Current/Best: 25.60/ 25.60 GFLOPS | Progress: (8/400) | 10.33 s ....
出现的 UserWarnings 可忽略。调优会使最终结果运行更快,但调优过程会耗费几个小时的时间。
参阅下面的“保存调优结果”部分,若要应用结果,务必将调优结果传给编译。
保存调优结果
把调优结果保存在文件中,方便以后复用。
- 方法 1:
log_file = "hello.json" # 运行 tuning tvmc.tune(model, target="llvm", tuning_records=log_file) ... # 运行 tuning,然后复用 tuning 的结果 tvmc.tune(model, target="llvm",tuning_records=log_file)
- 方法 2:
# 运行 tuning tuning_records = tvmc.tune(model, target="llvm") ... # 运行 tuning,然后复用 tuning 的结果 tvmc.tune(model, target="llvm",tuning_records=tuning_records)
参考资料:
使用 TVMC Python 入门:TVM 的高级 API | Apache TVM 中文站 (hyper.ai)