Deep learning model conversion: PyTorch to TensorFlow as an example

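Here the conversion goes through ONNX as an intermediate format. The main tools are PyTorch's built-in ONNX exporter (torch.onnx) and the onnx and onnx_tf libraries.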

STEP 1. Convert the PyTorch model to an ONNX model

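The key point here is to construct a dummy input for the model: torch.onnx.export runs the model once on this input to record its graph. This example assumes the model takes two inputs.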

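import numpy as np
import torch

pmodel = PytorchModel()  # your own torch.nn.Module; assumed here to take two inputs
pmodel.eval()  # export in inference mode (dropout off, BatchNorm uses running statistics)
dummy_input = (np.zeros((1, 30), dtype=np.float32), np.zeros((1, 2), dtype=np.float32))
torch.onnx.export(pmodel, (torch.as_tensor(dummy_input[0]), torch.as_tensor(dummy_input[1])), "/tmp/xx.onnx",
                  verbose=True, input_names=['input1', 'input2'], output_names=['output1', 'output2'])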

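The input_names parameter assigns names to the model's inputs, in the order they appear in the input tuple (any names will do); output_names does the same for the outputs.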

STEP 2. Convert the ONNX model to TensorFlow

This step relies on the onnx_tf library (installed with pip install onnx-tf):

import onnx
from onnx_tf.backend import prepare

onnx_model = onnx.load("/tmp/xx.onnx")  # load onnx model
tf_model = prepare(onnx_model)
tf_model.export_graph("/tmp/xxpb/")  # export the model

STEP 3. Use the TensorFlow model

import tensorflow as tf
import numpy as np

model_path = '/tmp/xxpb/'

# Load the SavedModel exported by onnx_tf into a TF1-style session
sess = tf.compat.v1.Session()
metagraph = tf.compat.v1.saved_model.loader.load(sess, [tf.compat.v1.saved_model.tag_constants.SERVING], model_path)
sig = metagraph.signature_def["serving_default"]
input_dict = dict(sig.inputs)
output_dict = dict(sig.outputs)
print(input_dict, output_dict)  # check which keys and tensor names the signature actually exposes
# Output keys as they appeared in this export (output_0/output_1); adjust to what the print shows
output_stochastic_act_label_0 = output_dict["output_0"].name
output_stochastic_act_label_1 = output_dict["output_1"].name

# Optional recurrent "state" input: read its first two dims from the signature's TensorShapeProto
input_state_label = None
initial_state = None
state = None
if "state" in input_dict:
    input_state_label = input_dict["state"].name
    dims = input_dict["state"].tensor_shape.dim
    dim_1, dim_2 = dims[0].size, dims[1].size
    initial_state = np.zeros((dim_1, dim_2), dtype=np.float32)
    state = np.zeros((dim_1, dim_2), dtype=np.float32)

# Input keys keep the input_names given to torch.onnx.export in STEP 1
input_obs_label_1 = input_dict["input1"].name  # shape (1, 30)
input_obs_label_0 = input_dict["input2"].name  # shape (1, 2)
feed_dict = {input_obs_label_0: np.zeros((1, 2), dtype=np.float32),
             input_obs_label_1: np.zeros((1, 30), dtype=np.float32)}
if input_state_label is not None:
    feed_dict[input_state_label] = state  # a "state" input must be fed as well
out = sess.run((output_stochastic_act_label_0, output_stochastic_act_label_1), feed_dict=feed_dict)
print(out)

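Note that the names have to be set up again at this point: sess.run feeds and fetches by the graph tensor names read from the signature (the .name fields), and the signature's output keys (output_0 and output_1 here) need not match the output_names chosen in STEP 1.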
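
As an alternative to the tf.compat.v1 session above, TensorFlow 2.x can load a SavedModel with tf.saved_model.load and expose the signature's input and output keys directly, so you can feed by key instead of by tensor name. This is a sketch assuming TF 2.x and, as STEP 3 already relies on, that the export has a serving_default signature. To keep it runnable it first saves a hypothetical ToyModule whose signature mimics the keys used above; ToyModule and load_serving_fn are illustrative names. Point load_serving_fn at /tmp/xxpb/ to inspect the real export.

```python
import tempfile

import tensorflow as tf


class ToyModule(tf.Module):
    """Hypothetical stand-in for the onnx_tf export, with the same signature keys."""

    @tf.function(input_signature=[tf.TensorSpec([1, 30], tf.float32, name="input1"),
                                  tf.TensorSpec([1, 2], tf.float32, name="input2")])
    def __call__(self, input1, input2):
        return {"output_0": tf.reduce_sum(input1, axis=1),
                "output_1": tf.reduce_sum(input2, axis=1)}


def load_serving_fn(model_path):
    """Load a SavedModel and return it together with its serving_default function."""
    loaded = tf.saved_model.load(model_path)
    # Keep `loaded` alive while using the function: it refers to the loaded object's resources
    return loaded, loaded.signatures["serving_default"]


export_dir = tempfile.mkdtemp()
module = ToyModule()
tf.saved_model.save(module, export_dir,
                    signatures={"serving_default": module.__call__.get_concrete_function()})

loaded, serving_fn = load_serving_fn(export_dir)
_, input_specs = serving_fn.structured_input_signature  # (args, kwargs); inputs arrive as kwargs
print(sorted(input_specs), sorted(serving_fn.structured_outputs))

# Feed by signature key rather than by graph tensor name
out = serving_fn(input1=tf.ones([1, 30]), input2=tf.ones([1, 2]))
print({k: v.numpy() for k, v in out.items()})
```

The printed keys are the ones to use when calling the function; with the real export, compare them with the print(input_dict, output_dict) output above.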

posted @ 2020-12-16 19:50  mrbean