如何将 ONNX 稳定地转换为 TensorFlow,甚至转换为 TFLite(float32/int8)
在把模型部署到边缘设备的时候,我们经常会遇到对模型格式的特定要求。但常见的 onnx2tf 很多时候都不能满足我们的要求。因此,在这里记录一下我的操作过程。
- 1. 环境:Ubuntu 18.04(Linux),下面是 `conda list` 的输出:
-
# Name Version Build Channel _libgcc_mutex 0.1 main defaults _openmp_mutex 5.1 1_gnu defaults absl-py 0.15.0 pypi_0 pypi addict 2.4.0 pypi_0 pypi altgraph 0.17.4 pypi_0 pypi array-record 0.4.0 pypi_0 pypi astunparse 1.6.3 pypi_0 pypi beautifulsoup4 4.11.2 pypi_0 pypi ca-certificates 2024.3.11 h06a4308_0 defaults cachetools 5.3.3 pypi_0 pypi certifi 2024.2.2 pypi_0 pypi charset-normalizer 3.3.2 pypi_0 pypi click 8.1.7 pypi_0 pypi colorama 0.4.6 pypi_0 pypi defusedxml 0.7.1 pypi_0 pypi dnspython 2.6.1 pypi_0 pypi editdistance 0.8.1 pypi_0 pypi etils 1.3.0 pypi_0 pypi fast-ctc-decode 0.3.6 pypi_0 pypi filelock 3.14.0 pypi_0 pypi flatbuffers 1.12 pypi_0 pypi fsspec 2024.5.0 pypi_0 pypi future 1.0.0 pypi_0 pypi gast 0.4.0 pypi_0 pypi google-auth 2.29.0 pypi_0 pypi google-auth-oauthlib 0.4.6 pypi_0 pypi google-pasta 0.2.0 pypi_0 pypi googleapis-common-protos 1.63.0 pypi_0 pypi grpcio 1.34.1 pypi_0 pypi h5py 3.1.0 pypi_0 pypi huggingface-hub 0.23.2 pypi_0 pypi hyperopt 0.1.2 pypi_0 pypi idna 3.7 pypi_0 pypi imageio 2.34.1 pypi_0 pypi importlib-metadata 7.1.0 pypi_0 pypi importlib-resources 6.4.0 pypi_0 pypi joblib 1.4.2 pypi_0 pypi jstyleson 0.0.2 pypi_0 pypi keras-nightly 2.5.0.dev2021032900 pypi_0 pypi keras-preprocessing 1.1.2 pypi_0 pypi libedit 3.1.20230828 h5eee18b_0 defaults libffi 3.2.1 hf484d3e_1007 defaults libgcc-ng 11.2.0 h1234567_1 defaults libgomp 11.2.0 h1234567_1 defaults libstdcxx-ng 11.2.0 h1234567_1 defaults markdown 3.6 pypi_0 pypi markupsafe 2.1.5 pypi_0 pypi ncurses 6.4 h6a678d5_0 defaults networkx 2.8.8 pypi_0 pypi nibabel 5.1.0 pypi_0 pypi nltk 3.8.1 pypi_0 pypi numpy 1.19.5 pypi_0 pypi oauthlib 3.2.2 pypi_0 pypi onnx 1.13.0 pypi_0 pypi opencv-python 4.5.5.64 pypi_0 pypi openssl 1.1.1w h7f8727e_0 defaults openvino 2021.4.2 pypi_0 pypi openvino-dev 2021.4.2 pypi_0 pypi openvino-telemetry 2024.1.0 pypi_0 pypi openvino2tensorflow 1.34.0 pypi_0 pypi opt-einsum 3.3.0 pypi_0 pypi packaging 24.0 pypi_0 pypi pandas 1.1.5 pypi_0 pypi parasail 1.3.4 pypi_0 
pypi pillow 9.4.0 pypi_0 pypi pip 24.0 py38h06a4308_0 defaults progress 1.6 pypi_0 pypi promise 2.3 pypi_0 pypi protobuf 3.20.3 pypi_0 pypi psutil 5.9.8 pypi_0 pypi py-cpuinfo 9.0.0 pypi_0 pypi pyasn1 0.6.0 pypi_0 pypi pyasn1-modules 0.4.0 pypi_0 pypi pydicom 2.4.4 pypi_0 pypi pyinstaller 6.7.0 pypi_0 pypi pyinstaller-hooks-contrib 2024.6 pypi_0 pypi pymongo 4.7.2 pypi_0 pypi python 3.8.0 h0371630_2 defaults python-dateutil 2.9.0.post0 pypi_0 pypi pytz 2024.1 pypi_0 pypi pywavelets 1.4.1 pypi_0 pypi pyyaml 6.0.1 pypi_0 pypi rawpy 0.21.0 pypi_0 pypi readline 7.0 h7b6447c_5 defaults regex 2024.5.15 pypi_0 pypi requests 2.32.3 pypi_0 pypi requests-oauthlib 2.0.0 pypi_0 pypi rsa 4.9 pypi_0 pypi scikit-image 0.19.3 pypi_0 pypi scikit-learn 1.3.2 pypi_0 pypi scipy 1.5.4 pypi_0 pypi sentencepiece 0.2.0 pypi_0 pypi setuptools 69.5.1 py38h06a4308_0 defaults shapely 2.0.4 pypi_0 pypi six 1.15.0 pypi_0 pypi soupsieve 2.5 pypi_0 pypi sqlite 3.33.0 h62c20be_0 defaults tensorboard 2.11.2 pypi_0 pypi tensorboard-data-server 0.6.1 pypi_0 pypi tensorboard-plugin-wit 1.8.1 pypi_0 pypi tensorflow 2.5.3 pypi_0 pypi tensorflow-datasets 4.9.2 pypi_0 pypi tensorflow-estimator 2.5.0 pypi_0 pypi tensorflow-metadata 1.14.0 pypi_0 pypi termcolor 1.1.0 pypi_0 pypi texttable 1.6.7 pypi_0 pypi threadpoolctl 3.5.0 pypi_0 pypi tifffile 2023.7.10 pypi_0 pypi tk 8.6.14 h39e8969_0 defaults tokenizers 0.19.1 pypi_0 pypi toml 0.10.2 pypi_0 pypi torch 1.12.1 pypi_0 pypi torchvision 0.13.1 pypi_0 pypi tqdm 4.66.4 pypi_0 pypi typing-extensions 3.7.4.3 pypi_0 pypi urllib3 2.2.1 pypi_0 pypi werkzeug 3.0.3 pypi_0 pypi wheel 0.43.0 py38h06a4308_0 defaults wrapt 1.12.1 pypi_0 pypi xz 5.4.6 h5eee18b_1 defaults yamlloader 1.4.1 pypi_0 pypi zipp 3.19.0 pypi_0 pypi zlib 1.2.13 h5eee18b_1 defaults
- 2. 具体代码:(下面是int8量化)
#!/usr/bin/env python
"""
Command line tool that converts an ONNX model (exported from PyTorch) to a
TensorFlow SavedModel and then to an int8-quantized TFLite model.

Pipeline:
    ONNX --(mo)--> OpenVINO IR --(openvino2tensorflow)--> SavedModel
    --(tf.lite.TFLiteConverter)--> int8 TFLite

Both ``mo`` and ``openvino2tensorflow`` are invoked as external commands and
must be reachable through PATH.

Example:
    torch-onnx2tflite.py -i yolox-tiny.onnx -s "(320, 320)"
        -d yolox/datasets/COCO/val2017 [-t yox_tiny.tflite]
"""
import argparse
import ast
import glob
import os
import random
import shlex
import shutil
import subprocess
from pathlib import Path
from typing import List

import cv2
import numpy as np
import tensorflow as tf
from tqdm import tqdm


def parse_args():
    """Parse the command line options of the converter.

    :return: an ``argparse.Namespace`` with ``input_onnx``, ``shape``,
        ``output_dir``, ``tflite_file``, ``dataset`` and
        ``num_present_images``
    """
    parser = argparse.ArgumentParser(
        description="Formatting PyTorch models to TensorFlow models")
    # -i, -s and -d are required: every later stage dereferences them, so a
    # missing value used to surface as an unrelated error deep in the run.
    parser.add_argument(
        "-i", "--input_onnx", type=str, required=True,
        help="an onnx file form pytorch model")
    parser.add_argument(
        "-s", "--shape", type=str, required=True,
        help="input image size (height, width)")
    parser.add_argument(
        "-o", "--output_dir", type=str, default="./",
        help="model output dir")
    parser.add_argument(
        "-t", "--tflite_file", type=str,
        help="output tflite file name")
    parser.add_argument(
        "-d", "--dataset", type=str, required=True,
        help="represent dataset")
    parser.add_argument(
        "-n",
        "--num_present_images",
        type=int,
        default=100,
        help="number of represent images for tflite quantization",
    )
    return parser.parse_args()


def _parse_hw(shape):
    """Parse a ``"(height, width)"`` string into a ``(h, w)`` pair.

    Uses ``ast.literal_eval`` so that text coming from the command line is
    never executed as code.

    :raises ValueError: if ``shape`` is not a literal pair of values
    """
    try:
        h, w = ast.literal_eval(shape)
    except (ValueError, SyntaxError, TypeError) as err:
        raise ValueError(
            f"invalid shape {shape!r}, expected '(height, width)'") from err
    return h, w


def _run_command(cmd, error_message):
    """Run an external command given as an argv list.

    :param cmd: list of strings, ``cmd[0]`` is the executable name
    :param error_message: message of the error raised on failure
    :raises RuntimeError: if the executable is not on PATH or exits non-zero
    """
    print(" ".join(shlex.quote(part) for part in cmd))
    executable = shutil.which(cmd[0])
    if executable is None:
        raise RuntimeError(
            f"{error_message}: '{cmd[0]}' was not found on PATH")
    # An argv list (no shell) keeps paths containing spaces intact, and an
    # explicit raise still works under ``python -O`` where asserts vanish.
    result = subprocess.run([executable, *cmd[1:]], check=False)
    if result.returncode != 0:
        raise RuntimeError(error_message)


def convert_onnx2tensorflow(args):
    """Convert ONNX to TensorFlow, applying model specific IR patches.

    The YOLOX patch is applied only when ``args.modify`` is truthy and
    ``args.modify_model == "yolox"``; both attributes are optional.

    :param args: parsed command line args
    :return: directory of the generated TensorFlow SavedModel
    """
    modify_xml_func = None
    # parse_args() does not define these options, so read them defensively.
    wants_patch = getattr(args, "modify", False)
    if wants_patch and getattr(args, "modify_model", None) == "yolox":
        modify_xml_func = mo_yolox_ov_xml
    return onnx2tensorflow(args, modify_xml_func)


def onnx2tensorflow(args, modify_xml_func=None):
    """Convert an ONNX model to a TensorFlow SavedModel through OpenVINO IR.

    :param args: parsed command line args (``input_onnx``, ``shape``,
        ``output_dir`` are used)
    :param modify_xml_func: optional callable taking the path of the
        OpenVINO ``.xml`` file, used to patch the IR before the second step
    :return: directory of the generated TensorFlow SavedModel
    :raises RuntimeError: if one of the external converters fails
    """
    print(f"CWD:{os.getcwd()}")
    output_dir = args.output_dir
    onnx_model = args.input_onnx
    h, w = _parse_hw(args.shape)
    input_shape = (1, 3, h, w)

    # Step 1: ONNX -> OpenVINO IR (always start from a clean directory).
    ov_dir = os.path.join(output_dir, "ov")
    shutil.rmtree(ov_dir, ignore_errors=True)
    _run_command(
        [
            "mo",
            "--input_model", onnx_model,
            "--input_shape", str(input_shape),
            "--output_dir", ov_dir,
            "--progress",
            "--data_type", "FP32",
        ],
        "failed in converting onnx to openvino",
    )

    ov_xml = os.path.join(ov_dir, f"{Path(onnx_model).stem}.xml")
    # add changes for certain models if needed
    if modify_xml_func:
        modify_xml_func(ov_xml)

    # Step 2: OpenVINO IR -> TensorFlow SavedModel (+ float32 tflite).
    tf_model_dir = os.path.join(output_dir, "hwc")
    shutil.rmtree(tf_model_dir, ignore_errors=True)
    _run_command(
        [
            "openvino2tensorflow",
            "--model_path", ov_xml,
            "--model_output_path", tf_model_dir,
            "--output_no_quant_float32_tflite",
            "--output_saved_model",
        ],
        "failed in converting openvino to tensorflow",
    )
    return tf_model_dir


def get_represent_images(path: str, num_present_images: int) -> List[Path]:
    """Randomly pick calibration images below ``path``.

    The directory is searched recursively for ``*.jpg`` first, then
    ``*.png``, then ``*.JPEG``; the first pattern with any match wins.

    :param path: root directory of the representative dataset
    :param num_present_images: number of images wanted; if fewer images
        exist, all of them are returned
    :return: list of sampled image paths
    :raises TypeError: if no image with a known extension is found
    """
    direc = Path(path)
    for pattern in ("*.jpg", "*.png", "*.JPEG"):
        files = list(direc.rglob(pattern))
        if files:
            break
    else:
        raise TypeError("unrecognised img file type")

    # random.sample raises ValueError when asked for more than it has.
    if num_present_images > len(files):
        print(f"only {len(files)} images found, "
              f"requested {num_present_images}; using all of them")
        num_present_images = len(files)
    return random.sample(files, num_present_images)


def tflite_quantize(args, tf_model_dir):
    """
    convert a tensorflow model to an int8 tflite model with represent data
    :param args: related command line args
    :param tf_model_dir: where the tensorflow model saved
    :return: path of the tflite file written into ``args.output_dir``
    """
    files = get_represent_images(args.dataset, args.num_present_images)
    h, w = _parse_hw(args.shape)

    def representative_dataset():
        """Yield preprocessed calibration samples, one image at a time."""
        for file in tqdm(files):
            img = cv2.imread(str(file))
            if img is None:
                # cv2.imread returns None for unreadable/corrupt files.
                print(f"skip unreadable image: {file}")
                continue
            # read images in RGB format
            # assume that the images were trained in RGB format
            img = img[..., ::-1]
            img = cv2.resize(img, (w, h))
            # scale pixel values from [0, 255] to [-1, 1]
            img = ((img - 127.5) / 127.5).astype(np.float32)
            img = img[None, ...]
            yield [img]

    converter = tf.lite.TFLiteConverter.from_saved_model(tf_model_dir)
    converter.optimizations = [tf.lite.Optimize.DEFAULT]
    # 8bits weight quantization
    converter.target_spec.supported_ops = [
        tf.lite.OpsSet.TFLITE_BUILTINS_INT8]
    converter.inference_input_type = tf.int8
    converter.inference_output_type = tf.int8
    # it must be false if TF version is not 2.4.1
    converter.experimental_new_quantizer = False
    converter.representative_dataset = representative_dataset
    tflite_model = converter.convert()

    # Honour -t/--tflite_file when given; otherwise derive the name from the
    # onnx file by swapping only its extension (str.replace would also
    # rewrite "onnx" inside the file stem or directory names).
    tflite_name = getattr(args, "tflite_file", None)
    if not tflite_name:
        tflite_name = Path(args.input_onnx).with_suffix(".tflite").name
    os.makedirs(args.output_dir, exist_ok=True)
    save_tflite_path = os.path.join(
        args.output_dir,
        os.path.basename(tflite_name),
    )
    with open(save_tflite_path, 'wb') as f:
        f.write(tflite_model)
    return save_tflite_path


def mo_yolox_ov_xml(ov_xml):
    """
    special modifications for yolox model: rewrite every
    ``<data axis="2"/>`` in the OpenVINO xml to ``<data axis="1"/>``, in place
    :param ov_xml: path of the OpenVINO IR xml file
    """
    with open(ov_xml) as f:
        xml = f.read()
    with open(ov_xml, "w") as f:
        f.write(xml.replace('<data axis="2"/>', '<data axis="1"/>'))


if __name__ == "__main__":
    onnx_args = parse_args()
    tensorflow_model_dir = onnx2tensorflow(onnx_args)
    tflite_quantize(onnx_args, tensorflow_model_dir)
- 3. 其他:
- 接着,我尝试使用 PyInstaller 将这个工具打包,生成一个 exe 来使用,但效果并不理想。因为该脚本的工作机制是通过命令行调用外部命令(mo 和 openvino2tensorflow),而 openvino2tensorflow 需要添加到环境变量(PATH)后才能被调用;如果目标机器上没有配置这个环境变量,生成的 exe 就无法正常工作。因此,打包 exe 的尝试目前没有实际用处。(我还没有找到解决方法,如果您有解决方法,请告诉我,不胜感激。)