TensorRT Learning Notes (2)
Continuing to organize TensorRT study material for later reference. (Most of the content is excerpted from online resources.)
1. TensorRT Plugins
For operators that TensorRT does not support, you can provide your own implementation through a plugin. The approach here: define a custom operator in PyTorch, export it to ONNX, and then implement a TensorRT plugin that parses this custom operator.
1.1 Defining a custom ONNX operator in PyTorch
Official documentation: https://pytorch.org/docs/1.10/onnx.html#torch-autograd-functions
Reference: https://zhuanlan.zhihu.com/p/513387413
A class that subclasses torch.autograd.Function and implements its forward() and backward() methods can be used like an ordinary PyTorch function inside a network. If it also implements the symbolic static method, calling torch.onnx.export() converts it into an ONNX operator. To summarize:
- For inference and training, the Function class itself represents a differentiable PyTorch function. As long as its forward and backward implementations are defined, it can be used like any ordinary PyTorch function; PyTorch dispatches it automatically and runs the forward and backward passes as appropriate.
- For deployment, the Function class has a very useful property: if it defines a symbolic static method, the Function is converted into ONNX operators during torch.onnx.export() according to the rules defined in symbolic.
- symbolic is the symbolic function; it usually returns a g.op() object. g.op() maps one PyTorch operator onto one or more ONNX operators, or onto a custom ONNX operator.
Below is the code for a custom activation operator named MYSELU (despite the name, its forward pass computes x * sigmoid(x), i.e. SiLU/swish; what matters for this demo is only that the name matches between the ONNX graph and the plugin):
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.onnx
import torch.autograd
import os

class MYSELUImpl(torch.autograd.Function):
    # reference: https://pytorch.org/docs/1.10/onnx.html#torch-autograd-functions
    @staticmethod
    def symbolic(g, x, p):
        print("==================================call symbolic")
        return g.op("MYSELU", x, p,
            g.op("Constant", value_t=torch.tensor([3, 2, 1], dtype=torch.float32)),
            attr1_s="this is a string attribute",
            attr2_i=[1, 2, 3],
            attr3_f=222
        )

    @staticmethod
    def forward(ctx, x, p):
        # x * sigmoid(x), i.e. SiLU
        return x * 1 / (1 + torch.exp(-x))

class MYSELU(nn.Module):
    def __init__(self, n):
        super().__init__()
        self.param = nn.parameter.Parameter(torch.arange(n).float())

    def forward(self, x):
        return MYSELUImpl.apply(x, self.param)

class Model(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv = nn.Conv2d(1, 1, 3, padding=1)
        self.myselu = MYSELU(3)
        self.conv.weight.data.fill_(1)
        self.conv.bias.data.fill_(0)

    def forward(self, x):
        x = self.conv(x)
        x = self.myselu(x)
        return x

# This package holds the opset 11 export code; to change export details, edit the code there
# import torch.onnx.symbolic_opset11
print("The opset export code lives in this folder:", os.path.dirname(torch.onnx.__file__))

model = Model().eval()
input = torch.tensor([
    # batch 0
    [
        [1, 1, 1],
        [1, 1, 1],
        [1, 1, 1],
    ],
    # batch 1
    [
        [-1, 1, 1],
        [1, 0, 1],
        [1, 1, -1]
    ]
], dtype=torch.float32).view(2, 1, 3, 3)

output = model(input)
print(f"inference output = \n{output}")

dummy = torch.zeros(1, 1, 3, 3)
torch.onnx.export(
    model,

    # args: the inputs given to the model; must be a tuple, hence the parentheses
    (dummy,),

    # output file path (must match the file that main.cpp parses below)
    "myselu.onnx",

    # print verbose information
    verbose=True,

    # name the input and output nodes, which makes them easier to inspect and manipulate later
    input_names=["image"],
    output_names=["output"],

    # opset: which rule set the operators are exported with; corresponds to symbolic_opset11
    opset_version=11,

    # batch, height and width are dynamic dimensions, represented as -1 in ONNX
    # usually only batch is made dynamic; avoid dynamic shapes on the other axes
    dynamic_axes={
        "image": {0: "batch", 2: "height", 3: "width"},
        "output": {0: "batch", 2: "height", 3: "width"},
    },

    # for a custom plugin operator, the ONNX checker must be disabled
    enable_onnx_checker=False
)
print("Done.")
The g.op() call returned above is worth a closer look:
g.op("MYSELU", x, p, # 表示onnx算子的名称为MYSELU
# 给算子传一个常数参数
g.op("Constant", value_t=torch.tensor([3, 2, 1], dtype=torch.float32)),
attr1_s="这是字符串属性", # s表示字符串
attr2_i=[1, 2, 3], # i表示整数
attr3_f=222 # f表示浮点数
)
After export, the MYSELU node in the ONNX graph carries exactly these inputs and attributes (in the graph visualization, the fields highlighted in red are the g.op arguments above).
1.2 Parsing the ONNX operator with a TensorRT plugin
Official documentation: https://docs.nvidia.com/deeplearning/tensorrt/developer-guide/index.html#extending
Defining a custom plugin in TensorRT means inheriting and implementing two classes, then registering the plugin's creator:

1. Inherit from nvinfer1::IPluginV2DynamicExt and implement the plugin itself: class MySELUPlugin : public nvinfer1::IPluginV2DynamicExt {};
   - Note that getPluginType() must return the same name the operator carries in the ONNX graph.
   - The operator's actual computation is normally written as a CUDA kernel, which is launched from enqueue().
2. Inherit from nvinfer1::IPluginCreator; this is a factory class used to instantiate the plugin: class MySELUPluginCreator : public nvinfer1::IPluginCreator {};
   - Note that getPluginName() must return the same name the operator carries in the ONNX graph.
3. Register the creator with the REGISTER_TENSORRT_PLUGIN macro: REGISTER_TENSORRT_PLUGIN(MySELUPluginCreator); (a quick way to verify the registration is sketched right below)
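REGISTER_TENSORRT_PLUGIN runs during static initialization and adds the creator to TensorRT's global plugin registry, which is what the ONNX parser later searches by operator name. As a sanity check, the creator can be looked up through the same registry; the following is a minimal sketch, assuming the plugin's translation unit is linked into the executable:

#include <NvInferRuntime.h>
#include <cstdio>

int main() {
    // getPluginRegistry() returns the global registry that
    // REGISTER_TENSORRT_PLUGIN populated during static initialization.
    auto* creator = getPluginRegistry()->getPluginCreator("MYSELU", "1");
    printf("MYSELU creator %s\n", creator ? "found" : "NOT found");
    return creator != nullptr ? 0 : 1;
}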
IPluginV2DynamicExt inherits from IPluginV2Ext, which in turn inherits from IPluginV2, so the virtual functions of all three base classes must be implemented. The main ones are:

From the IPluginV2DynamicExt base class:
- constructor and destructor
- virtual DimsExprs getOutputDimensions(): the dimensions of the output tensor
- virtual bool supportsFormatCombination(): the supported data types (int8, float16, float32, ...)
- virtual void configurePlugin(): configure the plugin format (the data format and type this operator runs with)
- virtual size_t getWorkspaceSize(): the amount of extra workspace required
- virtual int enqueue(): the actual inference logic

From the IPluginV2Ext base class:
- virtual nvinfer1::DataType getOutputDataType()

From the IPluginV2 base class:
- virtual AsciiChar const* getPluginType()
- virtual AsciiChar const* getPluginVersion()
- virtual int32_t getNbOutputs()
- virtual size_t getSerializationSize()
- virtual void serialize(void* buffer)

From the IPluginCreator base class, the main virtual functions to implement are:
- constructor and destructor
- virtual AsciiChar const* getPluginName()
- virtual AsciiChar const* getPluginVersion()
- virtual PluginFieldCollection const* getFieldNames()
- virtual IPluginV2* createPlugin()
- virtual IPluginV2* deserializePlugin()
- virtual void setPluginNamespace()
- virtual AsciiChar const* getPluginNamespace()
Below is the plugin implementation.

myselu_plugin.hpp:
#ifndef CUSTOM_MYSELU_PLUGIN_H
#define CUSTOM_MYSELU_PLUGIN_H

#include <NvInferPlugin.h>
#include <string>
#include <vector>

class MySELUPlugin : public nvinfer1::IPluginV2DynamicExt {
public:
    // constructor used while building the engine: takes the layer name and the operator attributes
    MySELUPlugin(const std::string name, const std::string attr1, float attr3);

    // constructor used at inference time: takes the layer name and the deserialized engine data
    MySELUPlugin(const std::string name, const void* data, size_t length);

    MySELUPlugin() = delete;

    int getNbOutputs() const noexcept override;

    virtual nvinfer1::DataType getOutputDataType(int32_t index,
        nvinfer1::DataType const* inputTypes, int32_t nbInputs) const noexcept override {
        return inputTypes[0];
    }

    virtual nvinfer1::DimsExprs getOutputDimensions(int32_t outputIndex,
        const nvinfer1::DimsExprs* inputs, int32_t nbInputs, nvinfer1::IExprBuilder& exprBuilder) noexcept override;

    int initialize() noexcept override;

    void terminate() noexcept override;

    virtual size_t getWorkspaceSize(const nvinfer1::PluginTensorDesc* inputs,
        int32_t nbInputs, const nvinfer1::PluginTensorDesc* outputs,
        int32_t nbOutputs) const noexcept override {
        return 0;
    }

    int enqueue(const nvinfer1::PluginTensorDesc* inputDesc, const nvinfer1::PluginTensorDesc* outputDesc,
        const void* const* inputs, void* const* outputs, void* workspace, cudaStream_t stream) noexcept override;

    size_t getSerializationSize() const noexcept override;

    void serialize(void* buffer) const noexcept override;

    virtual void configurePlugin(const nvinfer1::DynamicPluginTensorDesc* in, int32_t nbInputs,
        const nvinfer1::DynamicPluginTensorDesc* out, int32_t nbOutputs) noexcept override;

    virtual bool supportsFormatCombination(int32_t pos, const nvinfer1::PluginTensorDesc* inOut, int32_t nbInputs,
        int32_t nbOutputs) noexcept override;

    const char* getPluginType() const noexcept override;

    const char* getPluginVersion() const noexcept override;

    void destroy() noexcept override;

    nvinfer1::IPluginV2DynamicExt* clone() const noexcept override;

    void setPluginNamespace(nvinfer1::AsciiChar const* pluginNamespace) noexcept override;

    const char* getPluginNamespace() const noexcept override;

private:
    const std::string mLayerName;
    std::string mattr1;
    float mattr3;
    size_t mInputVolume;
    std::string mNamespace;
};

class MySELUPluginCreator : public nvinfer1::IPluginCreator {
public:
    MySELUPluginCreator();

    const char* getPluginName() const noexcept override;

    const char* getPluginVersion() const noexcept override;

    const nvinfer1::PluginFieldCollection* getFieldNames() noexcept override;

    nvinfer1::IPluginV2* createPlugin(nvinfer1::AsciiChar const* name,
        nvinfer1::PluginFieldCollection const* fc) noexcept override;

    nvinfer1::IPluginV2* deserializePlugin(nvinfer1::AsciiChar const* name,
        void const* serialData, size_t serialLength) noexcept override;

    void setPluginNamespace(nvinfer1::AsciiChar const* pluginNamespace) noexcept override;

    const char* getPluginNamespace() const noexcept override;

private:
    static nvinfer1::PluginFieldCollection mfc;
    static std::vector<nvinfer1::PluginField> mPluginAttributes;
    std::string mNamespace;
};

#endif // CUSTOM_MYSELU_PLUGIN_H
myselu_plugin.cpp:
#include "myselu_plugin.hpp"
#include <NvInfer.h>
#include <cstring>
#include <vector>
#include <cassert>
void myselu_inference(const float* x, float* output, int n, cudaStream_t stream);
// MySELU plugin的特定常量
namespace {
const char* MYSELU_PLUGIN_VERSION{ "1" };
const char* MYSELU_PLUGIN_NAME{ "MYSELU" }; //名称要和onnx中对应的一致
}
// 静态类字段的初始化
nvinfer1::PluginFieldCollection MySELUPluginCreator::mfc{};
std::vector<nvinfer1::PluginField> MySELUPluginCreator::mPluginAttributes;
// 实际注册时,注册的是创建器,交给tensorRT管理
REGISTER_TENSORRT_PLUGIN(MySELUPluginCreator);
// 用于序列化插件的Helper function
template <typename T>
void writeToBuffer(char*& buffer, const T& val) {
*reinterpret_cast<T*>(buffer) = val;
buffer += sizeof(T);
}
// 用于反序列化插件的Helper function
template <typename T>
T readFromBuffer(const char*& buffer) {
T val = *reinterpret_cast<const T*>(buffer);
buffer += sizeof(T);
return val;
}
// 定义插件类MYSELUPlugin
MySELUPlugin::MySELUPlugin(const std::string name, const std::string attr1, float attr3)
:mLayerName(name), mattr1(attr1), mattr3(attr3)
{
printf("==================== 编译阶段,attr1 = %s, attr3 = %f\n", attr1.c_str(), attr3);
};
MySELUPlugin::MySELUPlugin(const std::string name, const void* data, size_t length)
:mLayerName(name)
{
// Deserialize in the same order as serialization
const char* d = static_cast<const char*>(data);
const char* a = d;
int nstr = readFromBuffer<int>(d);
mattr1 = std::string(d, d + nstr);
d += nstr;
mattr3 = readFromBuffer<float>(d);
assert(d == (a + length));
printf("==================== 推理阶段,attr1 = %s, attr3 = %f\n", mattr1.c_str(), mattr3);
};
const char* MySELUPlugin::getPluginType() const noexcept
{
return MYSELU_PLUGIN_NAME;
}
const char* MySELUPlugin::getPluginVersion() const noexcept
{
return MYSELU_PLUGIN_VERSION;
}
int MySELUPlugin::getNbOutputs() const noexcept {
return 1;
}
// 获取该层的输出维度是多少
nvinfer1::DimsExprs MySELUPlugin::getOutputDimensions(int32_t outputIndex,
const nvinfer1::DimsExprs* inputs, int32_t nbInputs, nvinfer1::IExprBuilder& exprBuilder) noexcept
{
// MySELUping不改变输入尺寸,所以输出尺寸将与输入尺寸相同
return *inputs;
}
int MySELUPlugin::initialize() noexcept
{
return 0;
}
int MySELUPlugin::enqueue(const nvinfer1::PluginTensorDesc* inputDesc, const nvinfer1::PluginTensorDesc* outputDesc,
const void* const* inputs, void* const* outputs, void* workspace, cudaStream_t stream) noexcept
{
void* output = outputs[0];
size_t volume = 1;
for (int i = 0; i < inputDesc->dims.nbDims; ++i) {
volume *= inputDesc->dims.d[i];
}
mInputVolume = volume;
myselu_inference(static_cast<const float*>(inputs[0]),
static_cast<float*>(output),
mInputVolume,
stream
);
return 0;
}
size_t MySELUPlugin::getSerializationSize() const noexcept
{
return sizeof(int) + mattr1.size() + sizeof(mattr3);
}
// 该层的参数序列化储存为trtmodel文件
void MySELUPlugin::serialize(void* buffer) const noexcept
{
char* d = static_cast<char*>(buffer);
const char* a = d;
int nstr = mattr1.size();
writeToBuffer(d, nstr);
memcpy(d, mattr1.data(), nstr);
d += nstr;
writeToBuffer(d, mattr3);
assert(d == a + getSerializationSize());
}
// 判断该插件所支持的数据格式和类型
bool MySELUPlugin::supportsFormatCombination(int32_t pos, const nvinfer1::PluginTensorDesc* inOut, int32_t nbInputs,
int32_t nbOutputs) noexcept
{
auto type = inOut[pos].type;
auto format = inOut[pos].format;
// 这个插件只支持普通的浮点数,以及NCHW输入格式
if (type == nvinfer1::DataType::kFLOAT && format == nvinfer1::PluginFormat::kLINEAR) {
return true;
}
else {
return false;
}
}
void MySELUPlugin::terminate() noexcept {}
void MySELUPlugin::destroy() noexcept
{
// This gets called when the network containing plugin is destroyed
delete this;
}
// 配置插件格式:目前这个层所采用的数据格式和类型
void MySELUPlugin::configurePlugin(const nvinfer1::DynamicPluginTensorDesc* in, int32_t nbInputs,
const nvinfer1::DynamicPluginTensorDesc* out, int32_t nbOutputs) noexcept
{
auto type = in->desc.type;
auto format = in->desc.format;
assert(nbOutputs == 1);
assert(type == nvinfer1::DataType::kFLOAT);
assert(format == nvinfer1::PluginFormat::kLINEAR);
}
// 克隆插件
nvinfer1::IPluginV2DynamicExt* MySELUPlugin::clone() const noexcept
{
printf("===================克隆插件=================\n");
auto plugin = new MySELUPlugin(mLayerName, mattr1, mattr3);
plugin->setPluginNamespace(mNamespace.c_str());
return plugin;
}
void MySELUPlugin::setPluginNamespace(const char* libNamespace) noexcept
{
mNamespace = libNamespace;
}
const char* MySELUPlugin::getPluginNamespace() const noexcept
{
return mNamespace.c_str();
}
// 插件创建器
MySELUPluginCreator::MySELUPluginCreator()
{
// 描述MySELUPlugin的必要PluginField参数
mPluginAttributes.emplace_back(nvinfer1::PluginField("attr1", nullptr, nvinfer1::PluginFieldType::kCHAR));
mPluginAttributes.emplace_back(nvinfer1::PluginField("attr3", nullptr, nvinfer1::PluginFieldType::kFLOAT32));
// 收集PluginField的参数
mfc.nbFields = mPluginAttributes.size();
mfc.fields = mPluginAttributes.data();
}
const char* MySELUPluginCreator::getPluginName() const noexcept
{
return MYSELU_PLUGIN_NAME;
}
const char* MySELUPluginCreator::getPluginVersion() const noexcept
{
return MYSELU_PLUGIN_VERSION;
}
const nvinfer1::PluginFieldCollection* MySELUPluginCreator::getFieldNames() noexcept
{
return &mfc;
}
// 创建plugin
nvinfer1::IPluginV2* MySELUPluginCreator::createPlugin(nvinfer1::AsciiChar const* name,
nvinfer1::PluginFieldCollection const* fc) noexcept
{
std::string attr1;
float attr3;
const nvinfer1::PluginField* fields = fc->fields;
// Parse fields from PluginFieldCollection
for (int i = 0; i < fc->nbFields; ++i) {
if (strcmp(fields[i].name, "attr1")==0) {
assert(fields[i].type == nvinfer1::PluginFieldType::kCHAR);
auto cp = static_cast<const char*>(fields[i].data);
attr1 = std::string(cp, cp + fields[i].length);
}
else if (strcmp(fields[i].name, "attr3") == 0) {
assert(fields[i].type == nvinfer1::PluginFieldType::kFLOAT32);
attr3 = *(static_cast<const float*>(fields[i].data));
}
}
return new MySELUPlugin(name, attr1, attr3);
}
// 反序列化插件参数进行创建
nvinfer1::IPluginV2* MySELUPluginCreator::deserializePlugin(nvinfer1::AsciiChar const* name,
void const* serialData, size_t serialLength)noexcept
{
// This object will be deleted when the network is destroyed, which will
// call MySELUPlugin::destroy()
return new MySELUPlugin(name, serialData, serialLength);
}
void MySELUPluginCreator::setPluginNamespace(const char* libNamespace) noexcept
{
mNamespace = libNamespace;
}
const char* MySELUPluginCreator::getPluginNamespace() const noexcept
{
return mNamespace.c_str();
}
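To make the serialization layout concrete, here is a small standalone sketch (no TensorRT required; the variable names and values are chosen for illustration) of the byte layout that serialize() writes and the deserializing constructor reads back: an int holding the string length, the raw string bytes, then the float:

#include <cassert>
#include <cstring>
#include <string>
#include <vector>

int main() {
    std::string attr1 = "this is a string attribute";
    float attr3 = 222.0f;

    // same layout as MySELUPlugin: [int length][attr1 bytes][float attr3]
    size_t total = sizeof(int) + attr1.size() + sizeof(float);
    std::vector<char> buf(total);

    char* d = buf.data();
    int nstr = static_cast<int>(attr1.size());
    std::memcpy(d, &nstr, sizeof(nstr)); d += sizeof(nstr);
    std::memcpy(d, attr1.data(), nstr);  d += nstr;
    std::memcpy(d, &attr3, sizeof(attr3)); d += sizeof(attr3);
    assert(d == buf.data() + total);  // mirrors the assert in serialize()

    // read back in the same order, as the deserializing constructor does
    const char* s = buf.data();
    int n; std::memcpy(&n, s, sizeof(n)); s += sizeof(n);
    std::string r1(s, s + n); s += n;
    float r3; std::memcpy(&r3, s, sizeof(r3)); s += sizeof(r3);
    assert(r1 == attr1 && r3 == attr3);
    return 0;
}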
The CUDA kernel, myselu_kernel.cu:
#include <cuda_runtime.h>
#include <cmath>

static __device__ float sigmoid(float x) {
    return 1 / (1 + expf(-x));
}

static __global__ void myselu_kernel(const float* x, float* output, int n)
{
    int position = threadIdx.x + blockDim.x * blockIdx.x;
    if (position >= n) return;
    output[position] = x[position] * sigmoid(x[position]);
}

void myselu_inference(const float* x, float* output, int n, cudaStream_t stream)
{
    const int nthreads = 512;
    int block_size = n > nthreads ? nthreads : n;
    int grid_size = (n + block_size - 1) / block_size;
    myselu_kernel<<<grid_size, block_size, 0, stream>>>(x, output, n);
}
The main program, main.cpp:
// TensorRT include
// header for building
#include <NvInfer.h>
// header for the ONNX parser
#include <NvOnnxParser.h>
// header for the inference runtime
#include <NvInferRuntime.h>

// CUDA include
#include <cuda_runtime.h>

// system include
#include <stdio.h>
#include <math.h>
#include <iostream>
#include <fstream>
#include <vector>

using namespace std;

inline const char* severity_string(nvinfer1::ILogger::Severity t) {
    switch (t) {
        case nvinfer1::ILogger::Severity::kINTERNAL_ERROR: return "internal_error";
        case nvinfer1::ILogger::Severity::kERROR: return "error";
        case nvinfer1::ILogger::Severity::kWARNING: return "warning";
        case nvinfer1::ILogger::Severity::kINFO: return "info";
        case nvinfer1::ILogger::Severity::kVERBOSE: return "verbose";
        default: return "unknown";
    }
}

class TRTLogger : public nvinfer1::ILogger {
public:
    virtual void log(Severity severity, nvinfer1::AsciiChar const* msg) noexcept override {
        if (severity <= Severity::kINFO) {
            // print colored text, formatted as follows:
            // printf("\033[47;33mtext\033[0m");
            // where \033[  is the start marker
            //       47     is the background color
            //       ;      is the separator
            //       33     is the text color
            //       m      ends the start marker
            //       \033[0m is the reset marker
            // the background or text color may be omitted
            // some color codes: https://blog.csdn.net/ericbar/article/details/79652086
            if (severity == Severity::kWARNING) {
                printf("\033[33m%s: %s\033[0m\n", severity_string(severity), msg);
            }
            else if (severity <= Severity::kERROR) {
                printf("\033[31m%s: %s\033[0m\n", severity_string(severity), msg);
            }
            else {
                printf("%s: %s\n", severity_string(severity), msg);
            }
        }
    }
} logger;

bool build_model() {
    TRTLogger logger;

    // the basic components that are always needed
    nvinfer1::IBuilder* builder = nvinfer1::createInferBuilder(logger);
    nvinfer1::IBuilderConfig* config = builder->createBuilderConfig();
    // the argument 1 sets the kEXPLICIT_BATCH flag
    nvinfer1::INetworkDefinition* network = builder->createNetworkV2(1);

    // the onnxparser's parse results are filled into network, much like adding layers via addConv etc.
    nvonnxparser::IParser* parser = nvonnxparser::createParser(*network, logger);
    if (!parser->parseFromFile("myselu.onnx", 1)) {
        network->destroy();
        config->destroy();
        builder->destroy();
        printf("load onnx file failed\n");
        return false;
    }

    int maxBatchSize = 10;
    printf("Workspace Size = %.2f MB\n", (1 << 28) / 1024.0f / 1024.0f);
    config->setMaxWorkspaceSize(1 << 28);

    // if the model has several inputs, several profiles are required
    auto profile = builder->createOptimizationProfile();
    auto input_tensor = network->getInput(0);
    int input_channel = input_tensor->getDimensions().d[1];

    // configure the minimum, optimal and maximum input shapes
    profile->setDimensions(input_tensor->getName(), nvinfer1::OptProfileSelector::kMIN, nvinfer1::Dims4(1, input_channel, 3, 3));
    profile->setDimensions(input_tensor->getName(), nvinfer1::OptProfileSelector::kOPT, nvinfer1::Dims4(1, input_channel, 3, 3));
    profile->setDimensions(input_tensor->getName(), nvinfer1::OptProfileSelector::kMAX, nvinfer1::Dims4(maxBatchSize, input_channel, 5, 5));
    config->addOptimizationProfile(profile);

    nvinfer1::ICudaEngine* engine = builder->buildEngineWithConfig(*network, *config);
    if (engine == nullptr) {
        printf("build engine failed\n");
        network->destroy();
        config->destroy();
        builder->destroy();
        return false;
    }

    // serialize the model and store it as a file
    nvinfer1::IHostMemory* model_data = engine->serialize();
    FILE* f = fopen("myselu.trtmodel", "wb");
    fwrite(model_data->data(), 1, model_data->size(), f);
    fclose(f);

    // release in the reverse order of construction
    model_data->destroy();
    parser->destroy();
    engine->destroy();
    network->destroy();
    config->destroy();
    builder->destroy();
    printf("Done.\n");
    return true;
}

///////////////////////////////////////////////////////////////////////////////////////////////////////

vector<unsigned char> load_file(const string& file) {
    ifstream in(file, ios::in | ios::binary);
    if (!in.is_open())
        return {};

    in.seekg(0, ios::end);
    size_t length = in.tellg();

    std::vector<uint8_t> data;
    if (length > 0) {
        in.seekg(0, ios::beg);
        data.resize(length);
        in.read((char*)&data[0], length);
    }
    in.close();
    return data;
}

void inference() {
    TRTLogger logger;
    auto engine_data = load_file("myselu.trtmodel");
    nvinfer1::IRuntime* runtime = nvinfer1::createInferRuntime(logger);
    nvinfer1::ICudaEngine* engine = runtime->deserializeCudaEngine(engine_data.data(), engine_data.size());
    if (engine == nullptr) {
        printf("Deserialize cuda engine failed.\n");
        runtime->destroy();
        return;
    }

    nvinfer1::IExecutionContext* execution_context = engine->createExecutionContext();
    cudaStream_t stream = nullptr;
    cudaStreamCreate(&stream);

    float input_data_host[] = {
        // batch 0
        1, 1, 1,
        1, 1, 1,
        1, 1, 1,

        // batch 1
        -1, 1, 1,
        1, 0, 1,
        1, 1, -1
    };
    float* input_data_device = nullptr;

    // a 3x3 input gives a 3x3 output
    const int ib = 2;
    const int iw = 3;
    const int ih = 3;
    float output_data_host[ib * iw * ih];
    float* output_data_device = nullptr;
    cudaMalloc(&input_data_device, sizeof(input_data_host));
    cudaMalloc(&output_data_device, sizeof(output_data_host));
    cudaMemcpyAsync(input_data_device, input_data_host, sizeof(input_data_host), cudaMemcpyHostToDevice, stream);

    // state explicitly which input shape is used for this inference
    execution_context->setBindingDimensions(0, nvinfer1::Dims4(ib, 1, ih, iw));
    float* bindings[] = { input_data_device, output_data_device };
    bool success = execution_context->enqueueV2((void**)bindings, stream, nullptr);
    cudaMemcpyAsync(output_data_host, output_data_device, sizeof(output_data_host), cudaMemcpyDeviceToHost, stream);
    cudaStreamSynchronize(stream);

    for (int b = 0; b < ib; ++b) {
        printf("batch %d. output_data_host = \n", b);
        for (int i = 0; i < iw * ih; ++i) {
            printf("%f, ", output_data_host[b * iw * ih + i]);
            if ((i + 1) % iw == 0)
                printf("\n");
        }
    }

    printf("Clean memory\n");
    cudaStreamDestroy(stream);
    cudaFree(input_data_device);
    cudaFree(output_data_device);
    execution_context->destroy();
    engine->destroy();
    runtime->destroy();
}

int main() {
    if (!build_model()) {
        return -1;
    }
    inference();
    std::cin.get();
    return 0;
}
While the code above runs, you can observe the phases the plugin passes through.

Build phase:
- The plugin is created via MySELUPluginCreator::createPlugin (this hand-off is mimicked by hand in the sketch after this list).
- Along the way MySELUPlugin::clone is called to clone the plugin.
- MySELUPlugin::supportsFormatCombination is called to query which data formats and types the plugin supports. This is where we tell the engine what kinds of inference this plugin can run; several can be supported, e.g. fp32, fp16, int8, and so on.
- MySELUPlugin::getOutputDimensions is called to obtain the layer's output dimensions.
- MySELUPlugin::enqueue is called for performance measurement (not guaranteed to run). If several combinations are supported, each is actually timed and the best-performing configuration is chosen.
- MySELUPlugin::configurePlugin is called to configure the plugin format, i.e. to tell the layer which data format and type it will run with.
- MySELUPlugin::serialize is called to serialize the layer's parameters into the trtmodel file.

Inference phase:
- The plugin is created via MySELUPluginCreator::deserializePlugin, which restores the serialized parameters.
- Along the way MySELUPlugin::clone is called to clone the plugin.
- MySELUPlugin::configurePlugin is called to configure the data type and format the plugin uses.
- MySELUPlugin::enqueue is called to run the actual inference.
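In the build phase it is the ONNX parser that triggers the first step: when it meets a node whose op type it does not know, it looks the creator up in the registry by that name and passes the node's attributes along as PluginFields. The sketch below mimics that hand-off by hand; it is a sketch only, the helper name create_myselu_manually and the attribute values are made up for illustration, and error handling is omitted:

#include <NvInferRuntime.h>
#include <string>
#include <vector>

nvinfer1::IPluginV2* create_myselu_manually() {
    auto* creator = getPluginRegistry()->getPluginCreator("MYSELU", "1");
    if (!creator) return nullptr;

    // pack the ONNX attributes the way the parser would
    std::string attr1 = "this is a string attribute";
    float attr3 = 222.0f;
    std::vector<nvinfer1::PluginField> fields{
        {"attr1", attr1.data(), nvinfer1::PluginFieldType::kCHAR, (int32_t)attr1.size()},
        {"attr3", &attr3, nvinfer1::PluginFieldType::kFLOAT32, 1},
    };
    nvinfer1::PluginFieldCollection fc{ (int32_t)fields.size(), fields.data() };

    // equivalent to what happens when the parser meets a MYSELU node
    return creator->createPlugin("myselu_node", &fc);
}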