How to reliably convert ONNX to TensorFlow, and further to TFLite (float32/int8)

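When deploying models to edge devices, we often have to deliver a specific model format, and the common onnx2tf route frequently fails to meet those requirements. This post records the process that worked for me: ONNX -> OpenVINO IR (mo) -> TensorFlow SavedModel (openvino2tensorflow) -> TFLite.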

  • 1. Environment (Linux, Ubuntu 18.04; output of conda list):
  • # Name                    Version                   Build  Channel
    _libgcc_mutex             0.1                        main    defaults
    _openmp_mutex             5.1                       1_gnu    defaults
    absl-py                   0.15.0                   pypi_0    pypi
    addict                    2.4.0                    pypi_0    pypi
    altgraph                  0.17.4                   pypi_0    pypi
    array-record              0.4.0                    pypi_0    pypi
    astunparse                1.6.3                    pypi_0    pypi
    beautifulsoup4            4.11.2                   pypi_0    pypi
    ca-certificates           2024.3.11            h06a4308_0    defaults
    cachetools                5.3.3                    pypi_0    pypi
    certifi                   2024.2.2                 pypi_0    pypi
    charset-normalizer        3.3.2                    pypi_0    pypi
    click                     8.1.7                    pypi_0    pypi
    colorama                  0.4.6                    pypi_0    pypi
    defusedxml                0.7.1                    pypi_0    pypi
    dnspython                 2.6.1                    pypi_0    pypi
    editdistance              0.8.1                    pypi_0    pypi
    etils                     1.3.0                    pypi_0    pypi
    fast-ctc-decode           0.3.6                    pypi_0    pypi
    filelock                  3.14.0                   pypi_0    pypi
    flatbuffers               1.12                     pypi_0    pypi
    fsspec                    2024.5.0                 pypi_0    pypi
    future                    1.0.0                    pypi_0    pypi
    gast                      0.4.0                    pypi_0    pypi
    google-auth               2.29.0                   pypi_0    pypi
    google-auth-oauthlib      0.4.6                    pypi_0    pypi
    google-pasta              0.2.0                    pypi_0    pypi
    googleapis-common-protos  1.63.0                   pypi_0    pypi
    grpcio                    1.34.1                   pypi_0    pypi
    h5py                      3.1.0                    pypi_0    pypi
    huggingface-hub           0.23.2                   pypi_0    pypi
    hyperopt                  0.1.2                    pypi_0    pypi
    idna                      3.7                      pypi_0    pypi
    imageio                   2.34.1                   pypi_0    pypi
    importlib-metadata        7.1.0                    pypi_0    pypi
    importlib-resources       6.4.0                    pypi_0    pypi
    joblib                    1.4.2                    pypi_0    pypi
    jstyleson                 0.0.2                    pypi_0    pypi
    keras-nightly             2.5.0.dev2021032900          pypi_0    pypi
    keras-preprocessing       1.1.2                    pypi_0    pypi
    libedit                   3.1.20230828         h5eee18b_0    defaults
    libffi                    3.2.1             hf484d3e_1007    defaults
    libgcc-ng                 11.2.0               h1234567_1    defaults
    libgomp                   11.2.0               h1234567_1    defaults
    libstdcxx-ng              11.2.0               h1234567_1    defaults
    markdown                  3.6                      pypi_0    pypi
    markupsafe                2.1.5                    pypi_0    pypi
    ncurses                   6.4                  h6a678d5_0    defaults
    networkx                  2.8.8                    pypi_0    pypi
    nibabel                   5.1.0                    pypi_0    pypi
    nltk                      3.8.1                    pypi_0    pypi
    numpy                     1.19.5                   pypi_0    pypi
    oauthlib                  3.2.2                    pypi_0    pypi
    onnx                      1.13.0                   pypi_0    pypi
    opencv-python             4.5.5.64                 pypi_0    pypi
    openssl                   1.1.1w               h7f8727e_0    defaults
    openvino                  2021.4.2                 pypi_0    pypi
    openvino-dev              2021.4.2                 pypi_0    pypi
    openvino-telemetry        2024.1.0                 pypi_0    pypi
    openvino2tensorflow       1.34.0                   pypi_0    pypi
    opt-einsum                3.3.0                    pypi_0    pypi
    packaging                 24.0                     pypi_0    pypi
    pandas                    1.1.5                    pypi_0    pypi
    parasail                  1.3.4                    pypi_0    pypi
    pillow                    9.4.0                    pypi_0    pypi
    pip                       24.0             py38h06a4308_0    defaults
    progress                  1.6                      pypi_0    pypi
    promise                   2.3                      pypi_0    pypi
    protobuf                  3.20.3                   pypi_0    pypi
    psutil                    5.9.8                    pypi_0    pypi
    py-cpuinfo                9.0.0                    pypi_0    pypi
    pyasn1                    0.6.0                    pypi_0    pypi
    pyasn1-modules            0.4.0                    pypi_0    pypi
    pydicom                   2.4.4                    pypi_0    pypi
    pyinstaller               6.7.0                    pypi_0    pypi
    pyinstaller-hooks-contrib 2024.6                   pypi_0    pypi
    pymongo                   4.7.2                    pypi_0    pypi
    python                    3.8.0                h0371630_2    defaults
    python-dateutil           2.9.0.post0              pypi_0    pypi
    pytz                      2024.1                   pypi_0    pypi
    pywavelets                1.4.1                    pypi_0    pypi
    pyyaml                    6.0.1                    pypi_0    pypi
    rawpy                     0.21.0                   pypi_0    pypi
    readline                  7.0                  h7b6447c_5    defaults
    regex                     2024.5.15                pypi_0    pypi
    requests                  2.32.3                   pypi_0    pypi
    requests-oauthlib         2.0.0                    pypi_0    pypi
    rsa                       4.9                      pypi_0    pypi
    scikit-image              0.19.3                   pypi_0    pypi
    scikit-learn              1.3.2                    pypi_0    pypi
    scipy                     1.5.4                    pypi_0    pypi
    sentencepiece             0.2.0                    pypi_0    pypi
    setuptools                69.5.1           py38h06a4308_0    defaults
    shapely                   2.0.4                    pypi_0    pypi
    six                       1.15.0                   pypi_0    pypi
    soupsieve                 2.5                      pypi_0    pypi
    sqlite                    3.33.0               h62c20be_0    defaults
    tensorboard               2.11.2                   pypi_0    pypi
    tensorboard-data-server   0.6.1                    pypi_0    pypi
    tensorboard-plugin-wit    1.8.1                    pypi_0    pypi
    tensorflow                2.5.3                    pypi_0    pypi
    tensorflow-datasets       4.9.2                    pypi_0    pypi
    tensorflow-estimator      2.5.0                    pypi_0    pypi
    tensorflow-metadata       1.14.0                   pypi_0    pypi
    termcolor                 1.1.0                    pypi_0    pypi
    texttable                 1.6.7                    pypi_0    pypi
    threadpoolctl             3.5.0                    pypi_0    pypi
    tifffile                  2023.7.10                pypi_0    pypi
    tk                        8.6.14               h39e8969_0    defaults
    tokenizers                0.19.1                   pypi_0    pypi
    toml                      0.10.2                   pypi_0    pypi
    torch                     1.12.1                   pypi_0    pypi
    torchvision               0.13.1                   pypi_0    pypi
    tqdm                      4.66.4                   pypi_0    pypi
    typing-extensions         3.7.4.3                  pypi_0    pypi
    urllib3                   2.2.1                    pypi_0    pypi
    werkzeug                  3.0.3                    pypi_0    pypi
    wheel                     0.43.0           py38h06a4308_0    defaults
    wrapt                     1.12.1                   pypi_0    pypi
    xz                        5.4.6                h5eee18b_1    defaults
    yamlloader                1.4.1                    pypi_0    pypi
    zipp                      3.19.0                   pypi_0    pypi
    zlib                      1.2.13               h5eee18b_1    defaults
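The conversion depends heavily on the exact versions listed above, especially tensorflow 2.5.3, onnx 1.13.0, openvino/openvino-dev 2021.4.2 and openvino2tensorflow 1.34.0. The sketch below is not part of the original tool. It compares the installed versions of the key packages with those pins. PINNED and check_versions are hypothetical names, and the distribution names are assumed to match what pip registered. It uses importlib.metadata, which is available from Python 3.8 (the Python version in this environment).

```python
"""Sketch: check that key packages match the versions listed above."""
from importlib import metadata
from typing import Callable, Dict, Tuple

# Hypothetical pin list, copied from the environment listing above
PINNED: Dict[str, str] = {
    "tensorflow": "2.5.3",
    "onnx": "1.13.0",
    "openvino": "2021.4.2",
    "openvino-dev": "2021.4.2",
    "openvino2tensorflow": "1.34.0",
    "numpy": "1.19.5",
    "protobuf": "3.20.3",
}


def check_versions(
    pinned: Dict[str, str],
    lookup: Callable[[str], str] = metadata.version,
) -> Dict[str, Tuple[str, str]]:
    """Return {package: (expected, found)} for every mismatch.

    A package that is not installed is reported with found == "missing".
    """
    mismatches = {}
    for name, expected in pinned.items():
        try:
            found = lookup(name)
        except metadata.PackageNotFoundError:
            found = "missing"
        if found != expected:
            mismatches[name] = (expected, found)
    return mismatches


if __name__ == "__main__":
    for pkg, (want, got) in check_versions(PINNED).items():
        print(f"{pkg}: expected {want}, found {got}")
```

If you run it before converting and it prints nothing, the key packages match this setup.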
  • 2. Code (the version below performs int8 quantization):
    #!/usr/bin/env python
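    """
    A command-line tool that converts an ONNX model (e.g. exported from PyTorch)
    to a TFLite model: ONNX -> OpenVINO IR (mo) -> TensorFlow SavedModel
    (openvino2tensorflow) -> int8 TFLite (tf.lite.TFLiteConverter).
    """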
    import random
    import os
    import tensorflow as tf
    import glob
    import cv2
    import numpy as np
    from tqdm import tqdm
    import argparse
    from pathlib import Path
    import shutil
    from typing import List
    
    
    def parse_args():
        parser = argparse.ArgumentParser(
            description="Convert an ONNX model (from PyTorch) to TensorFlow and TFLite")
        parser.add_argument(
            "-i", "--input_onnx", type=str, required=True,
            help="ONNX file exported from a PyTorch model")
        parser.add_argument(
            "-s", "--shape", type=str, required=True,
            help='input image size as "(height, width)", e.g. "(320, 320)"')
        parser.add_argument(
            "-o", "--output_dir", type=str, default="./", help="model output directory")
        parser.add_argument(
            "-t", "--tflite_file", type=str, default=None,
            help="output TFLite file name (default: <onnx file stem>.tflite)")
        parser.add_argument(
            "-d", "--dataset", type=str, required=True,
            help="directory of representative images used for int8 calibration")
        parser.add_argument(
            "-n", "--num_present_images", type=int, default=100,
            help="number of representative images for TFLite quantization",
        )
        args = parser.parse_args()
        return args
    
    
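    def convert_onnx2tensorflow(args):
        # like onnx2tensorflow, but applies the YOLOX IR patch when requested;
        # getattr avoids an AttributeError because parse_args() does not define
        # these options (__main__ calls onnx2tensorflow directly)
        modify_xml_func = None
        if getattr(args, "modify", False) and getattr(args, "modify_model", None) == "yolox":
            modify_xml_func = mo_yolox_ov_xml
        return onnx2tensorflow(args, modify_xml_func)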
    
    
    def onnx2tensorflow(args, modify_xml_func=None):
        print(f"CWD:{os.getcwd()}")
        output_dir = args.output_dir
        onnx_model = args.input_onnx
        # parse "(h, w)" without eval()
        h, w = (int(v) for v in args.shape.strip("()[] ").split(","))
        input_shape = (1, 3, h, w)
        ov_dir = os.path.join(output_dir, "ov")
    
        shutil.rmtree(ov_dir, ignore_errors=True)
        ov_cmd = (
            f"mo  \
            --input_model {onnx_model} \
            --input_shape '{input_shape}' \
            --output_dir {ov_dir} \
            --progress \
            --data_type FP32"
        )
        print(ov_cmd)
        assert os.system(ov_cmd) == 0, "failed in converting onnx to openvino"
    
        ov_xml = os.path.join(ov_dir, f"{Path(onnx_model).stem}.xml")
        # add changes for certain models if needed
        if modify_xml_func:
            modify_xml_func(ov_xml)
    
        tf_model_dir = os.path.join(output_dir, "hwc")
        shutil.rmtree(tf_model_dir, ignore_errors=True)
        ov2tf_cmd = (
            f"openvino2tensorflow  \
            --model_path {ov_xml} \
            --model_output_path {tf_model_dir} \
            --output_no_quant_float32_tflite \
            --output_saved_model"
        )
        print(ov2tf_cmd)
        assert os.system(ov2tf_cmd) == 0, \
            "failed in converting openvino to tensorflow"
    
        return tf_model_dir
    
    
    def get_represent_images(path: str, num_present_images: int) -> List[Path]:
        """Randomly pick up to num_present_images calibration images under path."""
        direc = Path(path)
        exts = {".jpg", ".jpeg", ".png"}
        files = [p for p in direc.rglob("*") if p.suffix.lower() in exts]
        if not files:
            raise FileNotFoundError(f"no .jpg/.jpeg/.png images found under {path}")
        # random.sample raises ValueError if asked for more items than exist
        return random.sample(files, min(num_present_images, len(files)))
    
    
    def tflite_quantize(args, tf_model_dir):
        """Convert a TensorFlow SavedModel to an int8 TFLite model calibrated with representative data.

        :param args: parsed command-line arguments
        :param tf_model_dir: directory containing the TensorFlow SavedModel
        :return: path of the saved .tflite file (inside args.output_dir)
        """
        # assert args.tflite_file.split(".")[-1] == "tflite"
        # if os.path.exists(args.tflite_file):
        #     os.remove(args.tflite_file)
    
        files = get_represent_images(args.dataset, args.num_present_images)
        # parse "(h, w)" without eval()
        h, w = (int(v) for v in args.shape.strip("()[] ").split(","))
    
        def representative_dataset():
            for file in tqdm(files):
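                # cv2 loads BGR; convert to RGB, assuming the model was trained on RGB input
                # (for a single-channel model, read with cv2.IMREAD_GRAYSCALE and add
                # img[..., None] after resizing instead)
                img = cv2.imread(str(file))
                if img is None:  # skip files OpenCV cannot decode
                    continue
                img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
                img = cv2.resize(img, (w, h))  # cv2.resize expects (width, height)
                # scale to [-1, 1]; this must match the preprocessing used in training
                img = ((img - 127.5) / 127.5).astype(np.float32)
                img = img[None, ...]  # NHWC batch of 1, matching the "hwc" SavedModel
                yield [img]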
    
        converter = tf.lite.TFLiteConverter.from_saved_model(tf_model_dir)
        converter.optimizations = [tf.lite.Optimize.DEFAULT]
        # full-integer quantization: int8 weights and activations (calibrated with
        # representative_dataset), int8 model input and output
        converter.target_spec.supported_ops = [tf.lite.OpsSet.TFLITE_BUILTINS_INT8]
        converter.inference_input_type = tf.int8
        converter.inference_output_type = tf.int8
        # in my tests this had to be False unless the TF version was 2.4.1
        converter.experimental_new_quantizer = False
        converter.representative_dataset = representative_dataset
        tflite_model = converter.convert()
    
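        # use -t if given, otherwise <onnx file stem>.tflite
        # (str.replace("onnx", ...) would also rewrite "onnx" inside the file name)
        tflite_name = args.tflite_file or f"{Path(args.input_onnx).stem}.tflite"
        save_tflite_path = os.path.join(
                args.output_dir, os.path.basename(tflite_name),
        )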
    
        with open(save_tflite_path, 'wb') as f:
            f.write(tflite_model)
    
        return save_tflite_path
    
    
    def mo_yolox_ov_xml(ov_xml):
        """
        Special fix for YOLOX: rewrite every <data axis="2"/> element in the
        OpenVINO IR XML to <data axis="1"/>.
        """
        with open(ov_xml) as f:
            xml = f.read()
        with open(ov_xml, "w") as f:
            f.write(xml.replace('<data axis="2"/>', '<data axis="1"/>'))
    
    
    if __name__ == "__main__":
        onnx_args = parse_args()
        tensorflow_model_dir = onnx2tensorflow(onnx_args)
        tflite_quantize(onnx_args, tensorflow_model_dir)
    
    # Usage:
    # python torch-onnx2tflite.py \
    #     -i yolox-tiny.onnx \
    #     -s "(320, 320)" \
    #     -d yolox/datasets/COCO/val2017 \
    #     [-t yox_tiny.tflite]    (optional)
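The converter sets inference_input_type and inference_output_type to tf.int8. As a result, the exported model expects int8 input and returns int8 output. At inference time, the caller therefore has to quantize the float32 image, normalized to [-1, 1] as in representative_dataset, using the input tensor's (scale, zero_point). The caller also has to dequantize the output. This follows the affine scheme q = round(x / scale) + zero_point, with x ≈ (q - zero_point) * scale.

The helpers below are a minimal sketch of that step; the function names are hypothetical. run_int8_tflite shows how they would plug into tf.lite.Interpreter. It assumes a single-input, single-output model and reads the parameters from get_input_details()[0]["quantization"].

```python
import numpy as np


def quantize(x: np.ndarray, scale: float, zero_point: int) -> np.ndarray:
    """float -> int8 with the affine scheme q = round(x / scale) + zero_point."""
    q = np.round(x / scale) + zero_point
    return np.clip(q, -128, 127).astype(np.int8)


def dequantize(q: np.ndarray, scale: float, zero_point: int) -> np.ndarray:
    """int8 -> float32: x = (q - zero_point) * scale."""
    return ((q.astype(np.int32) - zero_point) * scale).astype(np.float32)


def run_int8_tflite(model_path: str, image: np.ndarray) -> np.ndarray:
    """Run one NHWC float32 image (batch of 1) through an int8 TFLite model."""
    import tensorflow as tf  # imported lazily so the helpers above work without TF

    interpreter = tf.lite.Interpreter(model_path=model_path)
    interpreter.allocate_tensors()
    inp = interpreter.get_input_details()[0]
    out = interpreter.get_output_details()[0]
    in_scale, in_zp = inp["quantization"]
    interpreter.set_tensor(inp["index"], quantize(image, in_scale, in_zp))
    interpreter.invoke()
    out_scale, out_zp = out["quantization"]
    return dequantize(interpreter.get_tensor(out["index"]), out_scale, out_zp)
```

Comparing the dequantized output with the float32 TFLite model that openvino2tensorflow writes (--output_no_quant_float32_tflite) on a few images gives a quick accuracy check for the int8 model.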
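  • 3. Other notes:
    •   Next, I tried to package this tool into a standalone executable with PyInstaller, but the result was unsatisfactory. The script works by running commands on the command line (os.system). openvino2tensorflow, like mo, only works once it is on the PATH environment variable. Without that, the generated executable is useless, so building it does not seem worthwhile for now. (I have not found a solution yet; if you have one, please let me know. I would be very grateful.)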
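One possible direction for the PyInstaller problem is offered here as an untested sketch, not a known fix. Instead of relying on the shell's PATH, the script would resolve the absolute paths of mo and openvino2tensorflow explicitly, for example from the bin directory of the conda environment where they are installed. It would then call them through subprocess with an argument list.

find_tool, run_tool and the OV_TOOLS_DIR environment variable are hypothetical names for illustration. The bundled executable would still need access to a Python environment that has openvino-dev and openvino2tensorflow installed.

```python
import os
import shutil
import subprocess
from typing import List, Optional


def find_tool(name: str, extra_dirs: Optional[List[str]] = None) -> str:
    """Return the absolute path of a command-line tool.

    Search order: extra_dirs, then the hypothetical OV_TOOLS_DIR environment
    variable, then the normal PATH.
    """
    dirs = list(extra_dirs or [])
    if os.environ.get("OV_TOOLS_DIR"):
        dirs.append(os.environ["OV_TOOLS_DIR"])
    for d in dirs:
        found = shutil.which(name, path=d)
        if found:
            return found
    found = shutil.which(name)  # fall back to the regular PATH lookup
    if found is None:
        raise FileNotFoundError(
            f"{name!r} not found in {dirs} or PATH; "
            "set OV_TOOLS_DIR to the environment's bin directory")
    return found


def run_tool(name: str, args: List[str],
             extra_dirs: Optional[List[str]] = None) -> None:
    """Run a tool with an argument list (no shell involved); raise on failure."""
    subprocess.run([find_tool(name, extra_dirs), *args], check=True)
```

For example, onnx2tensorflow could call run_tool("openvino2tensorflow", ["--model_path", ov_xml, "--model_output_path", tf_model_dir, "--output_no_quant_float32_tflite", "--output_saved_model"]) instead of os.system. Passing a list also avoids shell-quoting problems with paths that contain spaces.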
posted @ 2024-06-05 14:11  张幼安