腾讯开源框架TNN调用过程
第一步：模型转换。按照GitHub上的说明一步一步操作即可，这一步没有坑。
第二步：用CMake生成VS工程。需要在CMakeLists.txt中打开（设为ON）要使用的加速器（accelerator）选项，否则GetDevice会返回NULL。
第三步:调用
#include "tnn/utils/dims_vector_utils.h"
#include "tnn/utils/blob_transfer_utils.h"
#include "tnn/core/tnn.h"

#include <cstdio>
#include <fstream>
#include <iterator>
#include <string>
#include <vector>

// Converts a dense float tensor from NHWC layout to NCHW layout.
// dst and src must each hold batch*channel*height*width floats and must not overlap.
static void ModifynhwcTonchw(float* dst, const float* src, int batch, int channel, int height, int width) {
    // size_t offsets so large tensors do not overflow int index arithmetic.
    const size_t plane = static_cast<size_t>(height) * width;  // elements in one H*W slice
    const size_t image = plane * channel;                       // elements in one batch item
    for (int n = 0; n < batch; n++) {
        for (int c = 0; c < channel; c++) {
            for (int h = 0; h < height; h++) {
                for (int w = 0; w < width; w++) {
                    const size_t pixel = static_cast<size_t>(h) * width + w;
                    dst[n * image + c * plane + pixel] = src[n * image + pixel * channel + c];
                }
            }
        }
    }
}

// Reads raw binary float32 values from input_file and copies them into the FIRST blob
// of blob_map via TNN_NS::CopyToDevice.
//
// blob_map       input blobs of the instance; only begin()->second is filled.
// input_file     raw float32 data ("rb" + fread), exactly Count(dims) values.
// command_queue  queue obtained from Instance::GetCommandQueue.
// input_is_nhwc  when true (the original behavior), the file data is treated as NHWC
//                and converted to NCHW using the blob's dims[0..3] as N, C, H, W.
//
// Returns false (after printing the reason) if the map is empty, the file cannot be
// opened or is too short, the blob is not 4-D while conversion is requested, or the
// device copy fails.
bool CopyDataToDeviceFromFile(const TNN_NS::BlobMap& blob_map, const std::string& input_file,
                              void* command_queue, bool input_is_nhwc = true) {
    if (blob_map.empty()) {
        printf("CopyDataToDeviceFromFile Err, network has no input blob\n");
        return false;
    }
    TNN_NS::Blob* device_blob = blob_map.begin()->second;
    TNN_NS::BlobDesc blob_desc = device_blob->GetBlobDesc();
    const int data_count = TNN_NS::DimsVectorUtils::Count(blob_desc.dims);
    if (data_count <= 0) {
        printf("CopyDataToDeviceFromFile Err, input blob has no elements\n");
        return false;
    }

    std::vector<float> input_data(static_cast<size_t>(data_count));
    FILE* fp = fopen(input_file.c_str(), "rb");
    if (fp == nullptr) {
        printf("CopyDataToDeviceFromFile Err,read input file failed: %s\n", input_file.c_str());
        return false;
    }
    // fread(buffer, element_size, element_count, stream) returns the number of elements read.
    const size_t read_count = fread(input_data.data(), sizeof(float), input_data.size(), fp);
    fclose(fp);
    if (read_count != input_data.size()) {
        printf("CopyDataToDeviceFromFile Err, expected %d floats but read %d from %s\n",
               data_count, static_cast<int>(read_count), input_file.c_str());
        return false;
    }

    if (input_is_nhwc) {
        // The conversion indexes dims[0..3]; anything but a 4-D blob would read out of range.
        if (blob_desc.dims.size() != 4) {
            printf("CopyDataToDeviceFromFile Err, NHWC->NCHW conversion needs a 4-D input blob\n");
            return false;
        }
        std::vector<float> nchw_data(input_data.size());
        ModifynhwcTonchw(nchw_data.data(), input_data.data(), blob_desc.dims[0], blob_desc.dims[1],
                         blob_desc.dims[2], blob_desc.dims[3]);
        input_data.swap(nchw_data);
    }

    // Wrap the host buffer in a Blob that shares the device blob's description.
    // NOTE(review): the BlobDesc (including device_type) is copied from the device blob;
    // on non-NAIVE devices the host-side desc may need a different device_type — TODO confirm
    // against CopyToDevice's expectations.
    TNN_NS::BlobHandle data_handle;
    data_handle.base = input_data.data();
    data_handle.bytes_offset = 0;
    TNN_NS::Blob data_blob(blob_desc, data_handle);

    TNN_NS::Status status = TNN_NS::CopyToDevice(device_blob, &data_blob, command_queue);
    if (status != TNN_NS::TNN_OK) {
        printf("CopyDataToDeviceFromFile Err, CopyToDevice failed: %s\n", status.description().c_str());
        return false;
    }
    return true;
}

// Copies the FIRST blob of blob_map back to host memory via TNN_NS::CopyFromDevice and
// writes it to out_file as text, one "%f" value per line.
// Returns false (after printing the reason) if the map is empty, the copy fails, or the
// output file cannot be opened.
bool CopyDataFromDevicveToFile(const TNN_NS::BlobMap& blob_map, const std::string& out_file, void* command_queue) {
    if (blob_map.empty()) {
        printf("CopyDataFromDevicveToFile Err, network has no output blob\n");
        return false;
    }
    TNN_NS::Blob* device_blob = blob_map.begin()->second;
    TNN_NS::BlobDesc blob_desc = device_blob->GetBlobDesc();
    const int data_count = TNN_NS::DimsVectorUtils::Count(blob_desc.dims);
    if (data_count <= 0) {
        printf("CopyDataFromDevicveToFile Err, output blob has no elements\n");
        return false;
    }

    std::vector<float> output_data(static_cast<size_t>(data_count));
    TNN_NS::BlobHandle data_handle;
    data_handle.base = output_data.data();
    data_handle.bytes_offset = 0;
    TNN_NS::Blob data_blob(blob_desc, data_handle);

    TNN_NS::Status status = TNN_NS::CopyFromDevice(&data_blob, device_blob, command_queue);
    if (status != TNN_NS::TNN_OK) {
        printf("CopyDataFromDevicveToFile Err, CopyFromDevice failed: %s\n", status.description().c_str());
        return false;
    }

    FILE* fp = fopen(out_file.c_str(), "w");
    if (fp == nullptr) {
        printf("CopyDataFromDevicveToFile Err, open output file failed: %s\n", out_file.c_str());
        return false;
    }
    for (const float value : output_data) {
        fprintf(fp, "%f\n", value);
    }
    fclose(fp);
    return true;
}

// Fills model_config for a TNN-format model: params[0] is the proto text, params[1] the
// model binary. If the model file cannot be opened, an empty string is pushed instead
// (the original fallback), so params always ends up with exactly two entries.
// Returns false if the proto file cannot be read.
static bool LoadTnnModelConfig(const std::string& proto_path, const std::string& model_path,
                               TNN_NS::ModelConfig& model_config) {
    model_config.model_type = TNN_NS::MODEL_TYPE_TNN;

    std::ifstream proto_stream(proto_path);
    if (!proto_stream.is_open() || !proto_stream.good()) {
        printf("read proto_file failed!\n");
        return false;
    }
    model_config.params.push_back(
        std::string((std::istreambuf_iterator<char>(proto_stream)), std::istreambuf_iterator<char>()));

    std::ifstream model_stream(model_path, std::ios::binary);
    if (!model_stream.is_open() || !model_stream.good()) {
        printf("read model_file failed, using empty model: %s\n", model_path.c_str());
        model_config.params.push_back("");
    } else {
        model_config.params.push_back(
            std::string((std::istreambuf_iterator<char>(model_stream)), std::istreambuf_iterator<char>()));
    }
    return true;
}

// Creates an instance of the already-initialized net, feeds input_file into the first
// input blob, runs Forward and dumps the first output blob to output_file.
// Returns 0 on success, 1 on any failure (the reason is printed).
static int RunInference(TNN_NS::TNN& net, TNN_NS::NetworkConfig& network_config,
                        const std::string& input_file, const std::string& output_file) {
    TNN_NS::Status ret;
    auto instance = net.CreateInst(network_config, ret);
    if (ret != TNN_NS::TNN_OK || !instance) {
        printf("CreateInst failed: %s\n", ret.description().c_str());
        return 1;
    }

    TNN_NS::BlobMap input_blob_maps;
    TNN_NS::BlobMap output_blob_maps;
    void* command_queue = nullptr;
    ret = instance->GetAllInputBlobs(input_blob_maps);
    if (ret != TNN_NS::TNN_OK) {
        printf("GetAllInputBlobs failed: %s\n", ret.description().c_str());
        return 1;
    }
    ret = instance->GetAllOutputBlobs(output_blob_maps);
    if (ret != TNN_NS::TNN_OK) {
        printf("GetAllOutputBlobs failed: %s\n", ret.description().c_str());
        return 1;
    }
    ret = instance->GetCommandQueue(&command_queue);
    if (ret != TNN_NS::TNN_OK) {
        printf("GetCommandQueue failed: %s\n", ret.description().c_str());
        return 1;
    }

    if (!CopyDataToDeviceFromFile(input_blob_maps, input_file, command_queue)) {
        return 1;
    }
    ret = instance->Forward();
    if (ret != TNN_NS::TNN_OK) {
        printf("Forward failed: %s\n", ret.description().c_str());
        return 1;
    }
    if (!CopyDataFromDevicveToFile(output_blob_maps, output_file, command_queue)) {
        return 1;
    }
    return 0;
}

// Example driver: loads test.opt.tnnproto / test.opt.tnnmodel, runs one inference on the
// NAIVE device with NCHW data format, reading input.txt (raw float32) and writing data.txt.
int main() {
    const std::string proto_name = "test.opt.tnnproto";
    const std::string model_name = "test.opt.tnnmodel";
    const std::string input_file = "input.txt";
    const std::string output_file = "data.txt";

    TNN_NS::ModelConfig model_config;
    if (!LoadTnnModelConfig(proto_name, model_name, model_config)) {
        return 1;
    }

    TNN_NS::NetworkConfig network_config;
    network_config.device_type = TNN_NS::DEVICE_NAIVE;
    network_config.data_format = TNN_NS::DATA_FORMAT_NCHW;

    TNN_NS::TNN net;
    TNN_NS::Status ret = net.Init(model_config);
    if (ret != TNN_NS::TNN_OK) {
        printf("TNN Init failed: %s\n", ret.description().c_str());
        return 1;
    }

    // The instance lives inside RunInference, so it is released before DeInit.
    const int exit_code = RunInference(net, network_config, input_file, output_file);
    net.DeInit();
    return exit_code;
}