C++ 使用 TorchScript 加载训练好的 PyTorch 模型

1. 首先从 PyTorch 官网下载 LibTorch，解压到当前项目目录下。注意 LibTorch 的版本应与导出模型时使用的 PyTorch 版本一致，否则 torch::jit::load 可能无法加载导出的模型。

2. 使用 torch.jit.trace 将 PyTorch 训练好的模型导出为 .pt 格式（TorchScript）。导出前必须先调用 model.eval()，否则 BatchNorm、Dropout 会以训练模式被记录进导出的模型。

 1 import torch
 2 from skimage import io, transform, color
 3 import numpy as np
 4 import os
 5 import torch.nn.functional as F
 6 import warnings
 7 warnings.filterwarnings("ignore")
 8 device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
 9 
10 labels = ['cock', 'drawing', 'neutral', 'porn', 'sexy']
11 path = "test/n_1.jpg"
12 im = io.imread(path)
13 if im.shape[2] == 4:
14     im = color.rgba2rgb(im)
15 
16 im = transform.resize(im, (224, 224))
17 im = np.transpose(im, (2, 0, 1))
18 dummy_input = np.expand_dims(im, 0)
19 inp = torch.from_numpy(dummy_input)
20 inp = inp.float()
21 model = torch.load(
22     "models/resnet50-epoch-0-accu-0.9213857428381079.pth", map_location='cpu')
23 traced_script_module = torch.jit.trace(model, inp)
24 output = model(inp)
25 probs = F.softmax(output).detach().numpy()[0]
26 pred = np.argmax(probs)
27 
28 traced_script_module.save("models/traced_resnet_model.pt")

在 C++ 中使用 TorchScript 加载 .pt 模型（图像预处理必须与训练时一致：转为 RGB 顺序、缩放到 224×224、像素值除以 255 归一化到 [0,1]，再按 ImageNet 均值/标准差逐通道标准化）

// One-stop header.
#include <torch/script.h>
// C++ frontend utilities (torch::NoGradGuard).
#include <torch/torch.h>

// headers for opencv
#include <opencv2/highgui/highgui.hpp>
#include <opencv2/imgproc/imgproc.hpp>
#include <opencv2/opencv.hpp>

#include <algorithm>
#include <cmath>
#include <cstdint>
#include <fstream>
#include <iostream>
#include <memory>
#include <string>
#include <vector>

 15 #define kIMAGE_SIZE 224
 16 #define kCHANNELS 3
 17 #define kTOP_K 1 //print top k predicted results
 18 
 19 bool LoadImage(std::string file_name, cv::Mat &image)
 20 {
 21   image = cv::imread(file_name); // CV_8UC3
 22   if (image.empty() || !image.data)
 23   {
 24     return false;
 25   }
 26   cv::cvtColor(image, image, CV_BGR2RGB);
 27   // scale image to fit
 28   cv::Size scale(kIMAGE_SIZE, kIMAGE_SIZE);
 29   cv::resize(image, image, scale);
 30 
 31   // convert [unsigned int] to [float]
 32   image.convertTo(image, CV_32FC3,1.0/255);
 33 
 34   return true;
 35 }
 36 
 37 bool LoadImageNetLabel(std::string file_name,
 38                        std::vector<std::string> &labels)
 39 {
 40   std::ifstream ifs(file_name);
 41   if (!ifs)
 42   {
 43     return false;
 44   }
 45   std::string line;
 46   while (std::getline(ifs, line))
 47   {
 48     labels.push_back(line);
 49   }
 50   return true;
 51 }
 52 
 53 int main(int argc, const char *argv[])
 54 {
 55   if (argc != 3)
 56   {
 57     std::cerr << "Usage:classifier <path-to-exported-script-module>  <path-to-lable-file> " << std::endl;
 58     return -1;
 59   }
 60 
 61   //load model
 62   torch::jit::script::Module module = torch::jit::load(argv[1]);
 63   // to GPU
 64   // module->to(at::kCUDA);
 65   std::cout << "== ResNet50 loaded!\n";
 66 
 67   //load labels(classes names)
 68   std::vector<std::string> labels;
 69   if (LoadImageNetLabel(argv[2], labels))
 70   {
 71     std::cout << "== Label loaded! Let's try it\n";
 72   }
 73   else
 74   {
 75     std::cerr << "Please check your label file path." << std::endl;
 76     return -1;
 77   }
 78 
 79   std::string file_name = "";
 80   cv::Mat image;
 81   while (true)
 82   {
 83     std::cout << "== Input image path: [enter q to exit]" << std::endl;
 84     std::cin >> file_name;
 85     if (file_name == "Q" || file_name == "q")
 86     {
 87       break;
 88     }
 89     if (LoadImage(file_name, image))
 90     {
 91       //read image tensor
 92       auto input_tensor = torch::from_blob(
 93           image.data, {1, kIMAGE_SIZE, kIMAGE_SIZE, kCHANNELS});
 94       input_tensor = input_tensor.permute({0, 3, 1, 2});
 95       input_tensor[0][0] = input_tensor[0][0].sub_(0.485).div_(0.229);
 96       input_tensor[0][1] = input_tensor[0][1].sub_(0.456).div_(0.224);
 97       input_tensor[0][2] = input_tensor[0][2].sub_(0.406).div_(0.225);
 98       // to GPU
 99       // input_tensor = input_tensor.to(at::kCUDA);
100 
101       torch::Tensor out_tensor = module.forward({input_tensor}).toTensor();
102 
103       auto results = out_tensor.sort(-1, true);
104       auto softmaxs = std::get<0>(results)[0].softmax(0);
105       auto indexs = std::get<1>(results)[0];
106 
107       for (int i = 0; i < kTOP_K; ++i)
108       {
109         auto idx = indexs[i].item<int>();
110         std::cout << "    ============= Top-" << i + 1 << " =============" << std::endl;
111         std::cout << "    Label:  " << labels[idx] << std::endl;
112         std::cout << "    With Probability:  "
113                   << softmaxs[i].item<float>() * 100.0f << "%" << std::endl;
114       }
115     }
116     else
117     {
118       std::cout << "Can't load the image, please check your path." << std::endl;
119     }
120   }
121 }

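使用 CMakeLists.txt 编译。配置时通过 CMAKE_PREFIX_PATH 指定 LibTorch 的解压路径（OpenCV 未安装到系统目录时再用 OpenCV_DIR 指定其 build 目录），例如：mkdir build && cd build && cmake -DCMAKE_PREFIX_PATH=/path/to/libtorch .. && make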

 1 cmake_minimum_required(VERSION 2.8)
 2 project(predict_demo)
 3 SET(CMAKE_CXX_FLAGS ${CMAKE_CXX_FLAGS} "-std=c++11 -O3")
 4 
 5 
 6 set(OpenCV_DIR  /home/buyizhiyou/opencv-3.4.4/build)
 7 find_package(OpenCV REQUIRED)
 8 find_package(Torch REQUIRED)
 9 
10 
11 # 添加头文件
12 include_directories( ${OpenCV_INCLUDE_DIRS} )
13 
14 add_executable(resnet_demo resnet_demo.cpp)
15 target_link_libraries(resnet_demo ${TORCH_LIBRARIES}  ${OpenCV_LIBS})
16 set_property(TARGET resnet_demo PROPERTY CXX_STANDARD 11)

运行（labels.txt 中每行一个类别名，行的顺序必须与训练时的类别索引一致，例如依次为 cock、drawing、neutral、porn、sexy）：

./resnet_demo   models/traced_resnet_model.pt  labels.txt

 

posted @ 2019-12-04 10:34  阿夏z  阅读(3949)  评论(1)  编辑  收藏  举报