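There are two ways to train a model for use from C++. One is to write the training code directly in C++. This lets you build the complete network, but it makes transfer learning hard to use, and almost every training pipeline relies on transfer learning today. The other is to train the model in Python and then use PyTorch's JIT (TorchScript) to export it as a model that C++ can call. This post covers the second approach. (Personally, I think a third option is also viable: serving the PyTorch model as a web service that various clients can call.)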
Take the ESRGAN inference code (https://github.com/xinntao/ESRGAN) as an example:
Python packages: pip install numpy opencv-python
Running the test script directly gives the following results (my version has a few changes, such as an added FPS calculation):
2. Converting the PyTorch model to Torch Script
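There are two ways to convert a PyTorch model to Torch Script. The first is tracing: the model is evaluated once on example inputs, and the flow of those inputs through the model is recorded. The second is to add explicit annotations to the model. These tell the Torch Script compiler that it may parse and compile the model code directly, subject to the constraints imposed by the Torch Script language.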
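The following is a minimal sketch of the scripting approach with torch.jit.script. It uses a toy module whose name, GatedScale, is hypothetical and not part of ESRGAN. The compiler parses the source of forward(), so the data-dependent if statement is kept as real control flow. Tracing would only record whichever branch the example input happened to take.

```python
import torch
import torch.nn as nn


class GatedScale(nn.Module):
    """Hypothetical toy module whose output depends on a data-dependent branch."""

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        if bool(x.sum() > 0):
            return x * 2
        return -x


eager = GatedScale()
# torch.jit.script compiles forward() from its source, so both branches survive.
scripted = torch.jit.script(eager)

pos = torch.ones(1, 3, 4, 4)
neg = -torch.ones(1, 3, 4, 4)
for x in (pos, neg):
    # The scripted module takes the same branch as the eager module.
    assert torch.equal(scripted(x), eager(x))

print(scripted.code)  # the TorchScript source produced by the compiler
```

Scripting is the safer choice whenever forward() branches on tensor values or shapes. Tracing is simpler for purely feed-forward networks such as the convolutional model below.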
- Converting the model to Torch Script with tracing
```python
import torch
import architecture as arch

# An instance of your model.
model = arch.RRDB_Net(3, 3, 64, 23, gc=32, upscale=4, norm_type=None, act_type='leakyrelu',
                      mode='CNA', res_scale=1, upsample_mode='upconv')
model.load_state_dict(torch.load('./models/RRDB_ESRGAN_x4.pth'), strict=True)
model.eval()

# An example input you would normally provide to your model's forward() method.
example = torch.rand(64, 3, 3, 3)

# Use torch.jit.trace to generate a torch.jit.ScriptModule via tracing.
traced_script_module = torch.jit.trace(model, example)
output = traced_script_module(torch.ones(64, 3, 3, 3))
traced_script_module.save("./models/RRDB_ESRGAN_x4_000.pt")

# The traced ScriptModule can now be evaluated identically to a regular PyTorch module
print(output)
```
Running the script prints the traced module's output. It is abridged here: all 64 batch entries are identical because the input is torch.ones.

```
(surper-resolution-pytorch) anpi-cn@anpi-cn:~/workspace_min/Super-Resolution/ESRGAN$ python model_jit_converter.py
tensor([[[[0.9618, 1.0375, 1.0242,  ..., 1.0049, 1.0399, 1.0255],
          [1.0199, 0.9996, 1.0096,  ..., 1.0269, 1.0140, 1.0267],
          [1.0290, 1.0154, 1.0161,  ..., 1.0201, 1.0077, 1.0298],
          ...,
          [1.0316, 1.0139, 1.0184,  ..., 1.0184, 1.0179, 1.0197],
          [1.0391, 1.0174, 1.0162,  ..., 1.0185, 1.0443, 1.0168],
          [1.0066, 1.0186, 0.9976,  ..., 1.0143, 1.0066, 1.0249]],

         [[1.0155, 1.0491, 1.0004,  ..., 0.9993, 0.9828, 0.9706],
          [0.9992, 1.0149, 1.0032,  ..., 0.9851, 0.9937, 0.9887],
          [0.9974, 1.0106, 1.0089,  ..., 1.0072, 1.0074, 1.0041],
          ...,
          [1.0130, 1.0036, 1.0059,  ..., 0.9979, 1.0065, 1.0133],
          [1.0066, 0.9955, 1.0034,  ..., 1.0030, 0.9875, 1.0011],
          [0.9788, 0.9983, 1.0113,  ..., 1.0106, 1.0381, 1.0248]],

         [[0.9570, 0.9789, 0.9720,  ..., 0.9920, 0.9740, 0.9940],
          [0.9522, 1.0182, 1.0109,  ..., 1.0181, 1.0060, 0.9842],
          [0.9872, 1.0062, 1.0112,  ..., 1.0172, 1.0072, 0.9803],
          ...,
          [1.0211, 1.0119, 1.0091,  ..., 1.0082, 1.0339, 1.0348],
          [0.9894, 1.0227, 1.0226,  ..., 0.9930, 1.0258, 1.0234],
          [0.9997, 0.9755, 0.9969,  ..., 1.0227, 1.0308, 1.0109]]],

        ...],
       grad_fn=<MkldnnConvolutionBackward>)
```
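Before shipping the .pt file, it is worth checking that the traced module reproduces the eager model. The sketch below does this on a small stand-in network. TinySR is hypothetical and is not the real RRDB_Net, so the sketch runs without the ESRGAN weights. The network is purely convolutional and has no Python control flow. A trace taken on the 64×3×3×3 example should therefore also hold for other input sizes, and the check uses a different shape on purpose.

```python
import torch
import torch.nn as nn


class TinySR(nn.Module):
    """Hypothetical stand-in for RRDB_Net: conv -> LeakyReLU -> conv."""

    def __init__(self) -> None:
        super().__init__()
        self.body = nn.Sequential(
            nn.Conv2d(3, 8, 3, padding=1),
            nn.LeakyReLU(0.2),
            nn.Conv2d(8, 3, 3, padding=1),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.body(x)


torch.manual_seed(0)
model = TinySR().eval()

with torch.no_grad():
    # Trace with the same example shape as the conversion script above.
    traced = torch.jit.trace(model, torch.rand(64, 3, 3, 3))

    # Compare traced and eager outputs on an input of a different shape.
    x = torch.rand(1, 3, 16, 16)
    max_diff = (traced(x) - model(x)).abs().max().item()

print(f"max |traced - eager| = {max_diff:.2e}")
```

Tracing inside torch.no_grad() also keeps autograd bookkeeping out of the check. The grad_fn shown in the output above appears because the original script calls the module with gradients enabled.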
3. Loading the Script Module in C++
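To load a serialized PyTorch model in C++, your application must depend on the PyTorch C++ API, also known as LibTorch. The LibTorch distribution contains a set of shared libraries, header files, and CMake build configuration files. CMake is not required in order to depend on LibTorch, but it is the recommended approach and will be well supported going forward. Here we use CMake and LibTorch to build a minimal C++ application that simply loads and executes the serialized PyTorch model.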
```cpp
#include <torch/script.h> // One-stop header.

#include <cassert>
#include <iostream>
#include <memory>

int main(int argc, const char* argv[]) {
  if (argc != 2) {
    std::cerr << "usage: example-app <path-to-exported-script-module>\n";
    return -1;
  }

  // Deserialize the ScriptModule from a file using torch::jit::load().
  // (LibTorch 1.0 API: from PyTorch 1.2 on, torch::jit::load returns the
  // Module by value instead of a std::shared_ptr.)
  std::shared_ptr<torch::jit::script::Module> module = torch::jit::load(argv[1]);

  assert(module != nullptr);
  std::cout << "ok\n";
}
```
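The <torch/script.h> header pulls in everything from LibTorch that the example needs. The application takes the file path of a serialized PyTorch ScriptModule as its only command-line argument. It then deserializes the module with the torch::jit::load() function, which takes this file path as input. In return we receive a shared pointer to a torch::jit::script::Module, the C++ equivalent of torch.jit.ScriptModule.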
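The Python counterpart of torch::jit::load is torch.jit.load. It is handy for confirming that a saved file deserializes and runs before you debug the C++ side. The minimal sketch below uses a hypothetical toy model and a temporary file in place of ./models/RRDB_ESRGAN_x4_000.pt.

```python
import os
import tempfile

import torch
import torch.nn as nn

# Hypothetical toy model standing in for the exported ESRGAN network.
model = nn.Sequential(nn.Conv2d(3, 3, 3, padding=1)).eval()
example = torch.rand(1, 3, 8, 8)
traced = torch.jit.trace(model, example)

with tempfile.TemporaryDirectory() as tmp:
    path = os.path.join(tmp, "toy_model.pt")
    traced.save(path)

    # Deserialize onto the CPU regardless of the device it was saved from.
    loaded = torch.jit.load(path, map_location="cpu")

    same = torch.allclose(loaded(example), traced(example))
    is_script_module = isinstance(loaded, torch.jit.ScriptModule)

print(is_script_module, same)
```

If this round trip fails in Python, the problem lies in the export rather than in the C++ build.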
Save the code above as example-app.cpp. A minimal CMakeLists.txt to build it:

```cmake
cmake_minimum_required(VERSION 3.0 FATAL_ERROR)
project(custom_ops)

find_package(Torch REQUIRED)

add_executable(example-app example-app.cpp)
target_link_libraries(example-app "${TORCH_LIBRARIES}")
set_property(TARGET example-app PROPERTY CXX_STANDARD 11)
```

The project directory looks like this:

```
example-app/
  CMakeLists.txt
  example-app.cpp
```

From inside example-app/, configure against the unpacked LibTorch directory and build:

```
cmake -DCMAKE_PREFIX_PATH=/home/anpi-cn/workspace_min/libtorch
make
```
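As an alternative to a separately downloaded LibTorch, a pip- or conda-installed PyTorch also ships the Torch CMake package files. torch.utils.cmake_prefix_path points at them. The sketch below is an assumption about your setup, not what was used in this post. It prints a value you could pass as -DCMAKE_PREFIX_PATH and checks that TorchConfig.cmake is present. Whether your own code links cleanly against that build also depends on the C++ ABI it was compiled with, which torch.compiled_with_cxx11_abi() reports.

```python
import os

import torch
import torch.utils

# Directory that should contain Torch/TorchConfig.cmake inside the installed package.
prefix = torch.utils.cmake_prefix_path
config = os.path.join(prefix, "Torch", "TorchConfig.cmake")

print("cmake -DCMAKE_PREFIX_PATH=" + prefix)
print("TorchConfig.cmake found:", os.path.isfile(config))
# Your code must be compiled with the same _GLIBCXX_USE_CXX11_ABI setting as libtorch.
print("built with C++11 ABI:", torch.compiled_with_cxx11_abi())
```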
```
(surper-resolution-pytorch) anpi-cn@anpi-cn:~/workspace_min/Super-Resolution/ESRGAN/example-app$ cmake -DCMAKE_PREFIX_PATH=/home/anpi-cn/workspace_min/libtorch
-- The C compiler identification is GNU 5.4.0
-- The CXX compiler identification is GNU 5.4.0
-- Check for working C compiler: /usr/bin/cc
-- Check for working C compiler: /usr/bin/cc -- works
-- Detecting C compiler ABI info
-- Detecting C compiler ABI info - done
-- Detecting C compile features
-- Detecting C compile features - done
-- Check for working CXX compiler: /usr/bin/c++
-- Check for working CXX compiler: /usr/bin/c++ -- works
-- Detecting CXX compiler ABI info
-- Detecting CXX compiler ABI info - done
-- Detecting CXX compile features
-- Detecting CXX compile features - done
-- Looking for pthread.h
-- Looking for pthread.h - found
-- Looking for pthread_create
-- Looking for pthread_create - not found
-- Looking for pthread_create in pthreads
-- Looking for pthread_create in pthreads - not found
-- Looking for pthread_create in pthread
-- Looking for pthread_create in pthread - found
-- Found Threads: TRUE
-- Found CUDA: /usr/local/cuda (found version "9.0")
-- Caffe2: CUDA detected: 9.0
-- Caffe2: CUDA nvcc is: /usr/local/cuda/bin/nvcc
-- Caffe2: CUDA toolkit directory: /usr/local/cuda
-- Caffe2: Header version is: 9.0
-- Found CUDNN: /usr/include
-- Found cuDNN: v7.4.1  (include: /usr/include, library: /usr/lib/x86_64-linux-gnu/libcudnn.so)
-- Autodetected CUDA architecture(s): 6.1
-- Added CUDA NVCC flags for: -gencode;arch=compute_61,code=sm_61
-- Found torch: /home/anpi-cn/workspace_min/libtorch/lib/libtorch.so
-- Configuring done
CMake Warning at CMakeLists.txt:6 (add_executable):
  Cannot generate a safe runtime search path for target example-app because
  there is a cycle in the constraint graph:

    dir 0 is [/home/anpi-cn/workspace_min/libtorch/lib]
    dir 1 is [/usr/local/cuda/lib64/stubs]
    dir 2 is [/home/anpi-cn/.conda/envs/surper-resolution-pytorch/lib]
      dir 3 must precede it due to runtime library [libcudart.so.9.0]
    dir 3 is [/usr/local/cuda/lib64]
      dir 2 must precede it due to runtime library [libnvrtc.so.9.0]

  Some of these libraries may not be found correctly.

-- Generating done
-- Build files have been written to: /home/anpi-cn/workspace_min/Super-Resolution/ESRGAN/example-app
(surper-resolution-pytorch) anpi-cn@anpi-cn:~/workspace_min/Super-Resolution/ESRGAN/example-app$ make
Scanning dependencies of target example-app
[ 50%] Building CXX object CMakeFiles/example-app.dir/example-app.cpp.o
[100%] Linking CXX executable example-app
[100%] Built target example-app
(surper-resolution-pytorch) anpi-cn@anpi-cn:~/workspace_min/Super-Resolution/ESRGAN/example-app$ ./example-app ../models/RRDB_ESRGAN_x4_000.pt
ok
```
4. Executing the Script Module in C++
Once the serialized model loads successfully in C++, add the following code to the application's main():
```cpp
// Create a vector of inputs.
std::vector<torch::jit::IValue> inputs;
inputs.push_back(torch::ones({64, 3, 3, 3}));

// Execute the model and turn its output into a tensor.
auto output = module->forward(inputs).toTensor();

// Print indices 0..4 of dim 1; slice() clamps end=5 to the 3 channels present.
std::cout << output.slice(/*dim=*/1, /*start=*/0, /*end=*/5) << '\n';
```
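The first lines set up the model's input. We create a vector of torch::jit::IValue and add a single input to it. To create the input tensor we use torch::ones(), the C++ API equivalent of torch.ones. We then run the module's forward() method on that vector. It returns an IValue, which toTensor() converts into a tensor.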
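For comparison, the sketch below shows the same step in Python on a hypothetical toy module rather than the ESRGAN export. Python has no IValue wrapper, so the module is called with tensors directly. The C++ output.slice(1, 0, 5) corresponds to output[:, 0:5]. With the 64×3×3×3 input used above, dimension 1 has only 3 channels, so the end index is simply clamped in both languages.

```python
import torch
import torch.nn as nn

# Hypothetical toy module: a shape-preserving 3 -> 3 channel convolution.
scripted = torch.jit.script(nn.Conv2d(3, 3, kernel_size=3, padding=1).eval())

# C++: inputs.push_back(torch::ones({64, 3, 3, 3}));
x = torch.ones(64, 3, 3, 3)

# C++: module->forward(inputs).toTensor();
with torch.no_grad():
    output = scripted.forward(x)

# C++: output.slice(/*dim=*/1, /*start=*/0, /*end=*/5)
sliced = output[:, 0:5]
print(sliced.shape)  # torch.Size([64, 3, 3, 3])
```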