linux 下 tensorflow c++ 调用pb模型预测实例(CVSample 鉴黄检测)
上次介绍了怎么使用java 调用CVSample 鉴黄项目检测模型,这次介绍C++ 怎么调用鉴黄检测模型。
tensorflow java 调用pb模型预测实例(CVSample 鉴黄检测)
CVSample 地址: https://github.com/kingroc711/CVSample/tree/master/TensorFlow/inception_model
不多介绍模型了,具体模型详细介绍可以看下: tensorflow java 调用pb模型预测实例(CVSample 鉴黄检测)
C++、Java 调用Tensorflow区别
个人感觉C++ 调用tensorflow 在搭建环境时,会比Java 麻烦。Java只需要Maven引入一个依赖,而C++ 需要先编译库,在整理需要引入的头文件,比Java麻烦很多。
在使用方便,和Java差不多,官网上都有示例,在用C++调用模型时,也是参照google官网示例,示例地址: https://gitee.com/mirrors/tensorflow/blob/v1.15.4/tensorflow/examples/label_image/main.cc
#include <stdlib.h> #include "tensorflow/core/public/session.h" using namespace tensorflow; using namespace std; static Status ReadEntireFile(tensorflow::Env* env, const string& filename, Tensor* output); int main(int argc, char *argv[]) { SessionOptions sessionOptions; Session *session = NewSession(sessionOptions); //pb文件路径 string modelPath = "/opt/work/build_work/TensorFlow/inception_model/output_graph.pb"; GraphDef graphDef; Status statud_load = ReadBinaryProto(Env::Default(), modelPath, &graphDef); if(statud_load.ok()) { cout << "load pb file success : " << modelPath << endl; } if( session->Create(graphDef).ok() ) { cout << "success graph in session " << endl; } vector<Tensor> outputs; string input = "DecodeJpeg/contents:0"; string output = "final_result:0"; Tensor input0(DT_STRING, TensorShape()); //图片文件 if(ReadEntireFile(tensorflow::Env::Default(), "/root/test.png", &input0).ok()) { cout << "图片读取成功!" << endl; } vector<pair<string, tensorflow::Tensor>> runInputs = { {"DecodeJpeg/contents:0", input0}, }; //预测 Status status = session->Run(runInputs, {output}, {}, &outputs); cout << status << endl; if (!status.ok()) { cout << "run failed!" << endl; } //处理输出结果,模型输出结果就是一维数组,按照索引0,1,2,3,4分别对应porn。neutral、hentai、drawings、sexy Tensor scores; scores = outputs[0]; tensorflow::TTypes<float>::Flat scores_flat = scores.flat<float>(); //scores_flat.size() 数量是 5,打印每一个分类分数 for(int i = 0; i < scores_flat.size(); i++) { cout << scores_flat(i) << endl; } return EXIT_SUCCESS; } //从官网示例中直接拿过来的 static Status ReadEntireFile(tensorflow::Env* env, const string& filename, Tensor* output) { tensorflow::uint64 file_size = 0; TF_RETURN_IF_ERROR(env->GetFileSize(filename, &file_size)); string contents; contents.resize(file_size); std::unique_ptr<tensorflow::RandomAccessFile> file; TF_RETURN_IF_ERROR(env->NewRandomAccessFile(filename, &file)); tensorflow::StringPiece data; TF_RETURN_IF_ERROR(file->Read(0, file_size, &data, &(contents)[0])); if (data.size() != file_size) { return tensorflow::errors::DataLoss("Truncated read of '", filename, "' expected ", file_size, " got ", data.size()); } output->scalar<tstring>()() = tstring(data); return Status::OK(); }