linux 下 tensorflow c++ 调用pb模型预测实例(CVSample 鉴黄检测)

上次介绍了怎么使用java 调用CVSample 鉴黄项目检测模型,这次介绍C++ 怎么调用鉴黄检测模型。

tensorflow java 调用pb模型预测实例(CVSample 鉴黄检测)

CVSample 地址: https://github.com/kingroc711/CVSample/tree/master/TensorFlow/inception_model

不多介绍模型了,具体模型详细介绍可以看下: tensorflow java 调用pb模型预测实例(CVSample 鉴黄检测)

C++、Java 调用Tensorflow区别

个人感觉C++ 调用tensorflow 在搭建环境时,会比Java 麻烦。Java只需要Maven引入一个依赖,而C++ 需要先编译库,在整理需要引入的头文件,比Java麻烦很多。

在使用方便,和Java差不多,官网上都有示例,在用C++调用模型时,也是参照google官网示例,示例地址: https://gitee.com/mirrors/tensorflow/blob/v1.15.4/tensorflow/examples/label_image/main.cc

#include <stdlib.h>
#include "tensorflow/core/public/session.h"

using namespace tensorflow;
using namespace std;
static Status ReadEntireFile(tensorflow::Env* env, const string& filename,
                             Tensor* output);
int main(int argc, char *argv[])
{
    SessionOptions sessionOptions;
    Session *session = NewSession(sessionOptions);
    //pb文件路径
    string modelPath = "/opt/work/build_work/TensorFlow/inception_model/output_graph.pb";
    GraphDef graphDef;
    Status statud_load = ReadBinaryProto(Env::Default(), modelPath, &graphDef);
    if(statud_load.ok()) {
        cout << "load pb file success : " << modelPath << endl;
    }


    if( session->Create(graphDef).ok() ) {
        cout << "success graph in session " << endl;
    }
    vector<Tensor> outputs;
    string input = "DecodeJpeg/contents:0";
    string output = "final_result:0";
    Tensor input0(DT_STRING, TensorShape());
    //图片文件
    if(ReadEntireFile(tensorflow::Env::Default(), "/root/test.png", &input0).ok()) {
        cout << "图片读取成功!" << endl;
    }
    vector<pair<string, tensorflow::Tensor>> runInputs = {
        {"DecodeJpeg/contents:0", input0},
    };
    //预测
    Status status = session->Run(runInputs, {output}, {}, &outputs);
    cout << status << endl;
    if (!status.ok()) {
        cout << "run failed!" << endl;
    }
    //处理输出结果,模型输出结果就是一维数组,按照索引0,1,2,3,4分别对应porn。neutral、hentai、drawings、sexy
    Tensor scores;
    scores = outputs[0];
    tensorflow::TTypes<float>::Flat scores_flat = scores.flat<float>();
    //scores_flat.size() 数量是 5,打印每一个分类分数
    for(int i = 0; i < scores_flat.size(); i++) {
        cout << scores_flat(i) << endl;
    }

    return EXIT_SUCCESS;
}

//从官网示例中直接拿过来的
static Status ReadEntireFile(tensorflow::Env* env, const string& filename,
                             Tensor* output) {
  tensorflow::uint64 file_size = 0;
  TF_RETURN_IF_ERROR(env->GetFileSize(filename, &file_size));

  string contents;
  contents.resize(file_size);

  std::unique_ptr<tensorflow::RandomAccessFile> file;
  TF_RETURN_IF_ERROR(env->NewRandomAccessFile(filename, &file));

  tensorflow::StringPiece data;
  TF_RETURN_IF_ERROR(file->Read(0, file_size, &data, &(contents)[0]));
  if (data.size() != file_size) {
    return tensorflow::errors::DataLoss("Truncated read of '", filename,
                                        "' expected ", file_size, " got ",
                                        data.size());
  }
  output->scalar<tstring>()() = tstring(data);
  return Status::OK();
}

 

posted @ 2022-03-26 18:10  耿明岩  阅读(399)  评论(0编辑  收藏  举报
希望能帮助到你,顺利解决问题! ...G(^_−)☆