MxNet 模型转Tensorflow pb模型

用mmdnn实现模型转换

参考链接:https://www.twblogs.net/a/5ca4cadbbd9eee5b1a0713af

  1. 安装mmdnn
    pip install mmdnn

     

  2. 准备好mxnet模型的.json文件和.params文件, 以InsightFace MxNet r50为例        https://github.com/deepinsight/insightface
  3. 用mmdnn运行命令行
    python -m mmdnn.conversion._script.convertToIR -f mxnet -n model-symbol.json -w model-0000.params -d resnet50 --inputShape 3,112,112 

     

     会生成resnet50.json(可视化文件) resnet50.npy(权重参数) resnet50.pb(网络结构)三个文件。

  4. 用mmdnn运行命令行
    python -m mmdnn.conversion._script.IRToCode -f tensorflow --IRModelPath resnet50.pb --IRWeightPath resnet50.npy --dstModelPath tf_resnet50.py 

     

     生成tf_resnet50.py文件,可以调用tf_resnet50.py中的KitModel函数加载npy权重参数重新生成原网络框架。

  5. 打开tf_resnet.py文件,修改load_weights()中的代码 (tensorflow=1.14.0报错) 

     try:
            weights_dict = np.load(weight_file).item()
        except:
            weights_dict = np.load(weight_file, encoding='bytes').item()

    改为

     try:
            weights_dict = np.load(weight_file, allow_pickle=True).item()
    except:
            weights_dict = np.load(weight_file, allow_pickle=True, encoding='bytes').item()

     

  6. 基于resnet50.npy和tf_resnet50.py文​​件,固化参数,生成PB文件:

    import tensorflow as tf
    import tf_resnet50 as tf_fun
    def netWork():
        model=tf_fun.KitModel("./resnet50.npy")
        return model
    def freeze_graph(output_graph):
        output_node_names = "output"
        data,fc1=netWork()
        fc1=tf.identity(fc1,name="output")
    
        graph = tf.get_default_graph()  # 獲得默認的圖
        input_graph_def = graph.as_graph_def()  # 返回一個序列化的圖代表當前的圖
        init = tf.global_variables_initializer()
        with tf.Session() as sess:
            sess.run(init)
            output_graph_def = tf.graph_util.convert_variables_to_constants(  # 模型持久化,將變量值固定
                sess=sess,
                input_graph_def=input_graph_def,  # 等於:sess.graph_def
                output_node_names=output_node_names.split(","))  # 如果有多個輸出節點,以逗號隔開
    
            with tf.gfile.GFile(output_graph, "wb") as f:  # 保存模型
                f.write(output_graph_def.SerializeToString())  # 序列化輸出
    
    if __name__ == '__main__':
        freeze_graph("frozen_insightface_r50.pb")
        print("finish!")

     

  7. 采用tensorflow的post-train quantization离线量化方法(有一定的精度损失)转换成tflite模型,从而完成端侧的模型部署:
    import tensorflow as tf
    
    convert=tf.lite.TFLiteConverter.from_frozen_graph("frozen_insightface_r50.pb",input_arrays=["data"],output_arrays=["output"],
                                                      input_shapes={"data":[1,112,112,3]})
    convert.post_training_quantize=True
    tflite_model=convert.convert()
    open("quantized_insightface_r50.tflite","wb").write(tflite_model)
    print("finish!")

     

posted on 2019-07-04 18:42  七昂的技术之旅  阅读(3496)  评论(0编辑  收藏  举报

导航