使用C++部署Keras或TensorFlow模型

本文介绍如何在C++环境中部署Keras或TensorFlow模型。

一、对于Keras,

第一步,使用Keras搭建、训练、保存模型。

# Persist the trained Keras model to a single HDF5 file; step 2 loads it back with load_model.
keras_model_path = './your_keras_model.h5'
model.save(keras_model_path)

第二步,冻结Keras模型:把计算图中的变量替换为保存着训练结果的常量,得到一个不依赖Python、可直接被C++加载的.pb文件。

from keras.models import load_model
import tensorflow as tf
from tensorflow.python.framework import graph_io
from keras import backend as K

def freeze_session(session, keep_var_names=None, output_names=None, clear_devices=True):
    """Freeze a TF1 session's graph into a GraphDef whose variables are baked in as constants.

    Args:
        session: the tf.Session (e.g. ``K.get_session()``) whose graph and current variable
            values are frozen.
        keep_var_names: names of global variables that must NOT be converted to constants.
        output_names: names of the output ops to keep; this iterable is never modified.
        clear_devices: if True, strip device placements from every node so the frozen graph
            can be loaded on any device.

    Returns:
        The frozen ``GraphDef`` produced by ``convert_variables_to_constants``.
    """
    from tensorflow.python.framework.graph_util import convert_variables_to_constants
    graph = session.graph
    with graph.as_default():
        freeze_var_names = list(set(v.op.name for v in tf.global_variables()).difference(keep_var_names or []))
        # Build a fresh list: `+=` on the argument itself would silently mutate the caller's list.
        output_names = list(output_names or [])
        # Also request every global variable as an output so its (constant-converted) node is
        # kept in the frozen graph even when none of the real outputs depends on it.
        output_names += [v.op.name for v in tf.global_variables()]
        input_graph_def = graph.as_graph_def()
        if clear_devices:
            for node in input_graph_def.node:
                node.device = ""
        frozen_graph = convert_variables_to_constants(session, input_graph_def, output_names, freeze_var_names)
    return frozen_graph

# Switch Keras to inference mode (learning phase 0) before loading, so the graph that gets
# frozen is built for prediction rather than training.
K.set_learning_phase(0)
keras_model = load_model('./your_keras_model.h5')
# These tensor names are the input/output node names the C++ program must use.
print('Inputs are:', keras_model.inputs)
print('Outputs are:', keras_model.outputs)

# Freeze the Keras session, keeping the loaded model's output ops, and write a binary .pb.
frozen_graph = freeze_session(K.get_session(), output_names=[out.op.name for out in keras_model.outputs])
graph_io.write_graph(frozen_graph, "./", "your_frozen_model.pb", as_text=False)

  

二、对于TensorFlow,

1、使用TensorFlow搭建、训练、保存模型。

# Save the session's variables under this checkpoint prefix. By default Saver.save also
# writes the graph as "<prefix>.meta", which the freezing step can read.
checkpoint_prefix = "./your_tf_model.ckpt"
saver = tf.train.Saver()
saver.save(sess, checkpoint_prefix)

2、冻结TensorFlow模型。

# freeze_graph needs the graph structure as well as the checkpoint weights: here the binary
# .meta file written by saver.save. The script ships with TensorFlow
# (also runnable as: python -m tensorflow.python.tools.freeze_graph).
python freeze_graph.py --input_meta_graph=./your_tf_model.ckpt.meta --input_binary=true --input_checkpoint=./your_tf_model.ckpt --output_graph=./your_frozen_model.pb --output_node_names=output_node

三、使用TensorFlow的C/C++接口调用冻结的模型。这里,我们向模型中输入一张经过OpenCV处理的图片。代码中的文件路径和输入/输出节点名称都是占位符,需要替换为你自己模型中的实际名称。

#include <iostream>
#include <memory>
#include <string>
#include <vector>

#include "opencv2/opencv.hpp"
#include "tensorflow/core/platform/env.h"
#include "tensorflow/core/public/session.h"
using namespace tensorflow;

int main(int argc, char* argv[]){
    // tell the network that it is not training
    phaseTensor = Tensor(DT_BOOL, TensorShape());
    auto phaseTensorPointer = phaseTensor.tensor<bool, 0>();
    phaseTensorPointer(0) = false;

    // read the input image
    cv::Mat img = imread('./your_input_image.png', 0);
    input_image_height = img.size().height;
    input_image_width = img.size().width;
    input_image_channels = img.channels();
    imageTensor = Tensor(DT_FLOAT, TensorShape({1, input_image_height, input_image_width, input_image_channels}));

    // convert the image to a tensor
    float * imageTensorPointer = imageTensor.flat<float>().data();
    cv::Mat imageTensorMatWarpper(input_image_height, input_image_width, CV_32FC3, imageTensorPointer);
    img.convertTo(imageTensorMatWarpper, CV_32FC3);

    // construct the input
    string input_node_name1 = "input tesnor name1";
    string input_node_name2 = "input tensor name2";
    std::vector<std::pair<string, Tensor>> inputs;
    inputs = {{input_node_name1, phaseTensor}, {input_node_name2, imageTensor},};

    // start a new session
    Session* session;
    Status status = NewSession(SessionOptions(), &session);
    if (!status.ok()) {
        cout << "NewSession failed! " << status.error_message() << std::endl;
    }
    // read the frozen graph
    GraphDef graph_def;
    status = ReadBinaryProto(Env::Default(), "./your_frozen_model.pb", &graph_def);
    if (!status.ok()) {
        cout << "ReadBinaryProto failed! " << status.error_message() << std::endl;
    }
    // initialize the session graph
    status = session->Create(graph_def);
    if (!status.ok()) {
        cout << "Create failed! " << status.error_message() << std::endl;
    }

    // define the output
    string output_node_name1 = "output tensor name1";
    std::vector<tensorflow::Tensor> outputs;

    // run the graph
    tensorflow::Status status = session->Run(inputs, {output_node_name1}, {}, &outputs);
    if (!status.ok()) {
        cout << "Run failed! " << status.error_message() << std::endl;
    }

    // obtain the output
    Tensor output = std::move(outputs[0]);
    tensorflow::StringPiece tmpBuff = output.tensor_data();
    const float* final_output = reinterpret_cast<const float*>(tmpBuff.data());

    //for classification problems, the output_data is a tensor of shape [batch_size, class_num]
    /*
    auto scores = outputs[0].flat<float>();
    */
    session->Close();
    return 0;
}

  

posted @ 2018-12-25 17:14  南乡水  阅读(6487)  评论(4)  编辑  收藏  举报