(原)在PyTorch中使用TensorRT
转载请注明出处:
https://www.cnblogs.com/darkknightzh/p/11332155.html
代码网址:
https://github.com/darkknightzh/TensorRT_pytorch
参考网址:
TensorRT安装包中的samples/python目录
https://github.com/pytorch/examples/tree/master/mnist
此处代码使用的是TensorRT 5.1.5。
安装完TensorRT之后,使用TensorRT主要包括下面几段代码:
1. 初始化
import tensorrt as trt import pycuda.driver as cuda import pycuda.autoinit # 此句代码中未使用,但是必须有。this is useful, otherwise stream = cuda.Stream() will cause 'explicit_context_dependent failed: invalid device context - no currently active context?'
如注释所示,import pycuda.autoinit这句话在程序中没有被直接使用,但必须包含:导入该模块时会自动初始化CUDA并创建上下文(context)。否则执行stream = cuda.Stream()时会报“no currently active context”错误。
2. 保存onnx模型
def saveONNX(model, filepath, c, h, w):
    """Export a PyTorch model to an ONNX file.

    The model is moved to the GPU and traced with one random input of shape
    (1, c, h, w). The exported graph therefore uses a batch size of 1.
    ``verbose=True`` makes the exporter print the traced graph.

    Args:
        model: the PyTorch module to export. ``.cuda()`` moves it to the GPU
            in place.
        filepath: destination path of the ``.onnx`` file.
        c: number of channels of the dummy input.
        h: height of the dummy input.
        w: width of the dummy input.
    """
    gpu_model = model.cuda()
    # Tracing only needs the input shape/dtype, so random values are fine.
    example_input = torch.randn(1, c, h, w, device='cuda')
    torch.onnx.export(gpu_model, example_input, filepath, verbose=True)
3. 创建tensorrt引擎
def _parse_onnx_file(parser, onnx_file_path):
    """Read an ONNX file and feed it to a TensorRT ONNX parser.

    Returns:
        True if the parser accepted the model. On failure, prints every error
        the parser recorded and returns False. Callers can then stop instead
        of building an engine from an empty or partial network.
    """
    with open(onnx_file_path, 'rb') as model:
        if parser.parse(model.read()):
            return True
    print('ERROR: failed to parse the ONNX file: {}'.format(onnx_file_path))
    for i in range(parser.num_errors):
        print(parser.get_error(i))
    return False


def build_engine(onnx_file_path):
    """Build a TensorRT engine (default FP32 precision) from an ONNX model file.

    Prints whether the GPU reports fast FP16 / INT8 support. This is
    informational only; neither mode is enabled here.

    Args:
        onnx_file_path: path to the ONNX model.

    Returns:
        The engine returned by ``builder.build_cuda_engine``, or None if the
        ONNX file could not be parsed.
    """
    TRT_LOGGER = trt.Logger(trt.Logger.WARNING)  # use trt.Logger.INFO for more output
    # For more information on TRT basics, refer to the introductory samples.
    with trt.Builder(TRT_LOGGER) as builder, \
            builder.create_network() as network, \
            trt.OnnxParser(network, TRT_LOGGER) as parser:
        if builder.platform_has_fast_fp16:
            print('this card support fp16')
        if builder.platform_has_fast_int8:
            print('this card support int8')
        builder.max_workspace_size = 1 << 30  # 1 GiB workspace limit
        # Do not build from a network the parser failed to populate.
        if not _parse_onnx_file(parser, onnx_file_path):
            return None
        return builder.build_cuda_engine(network)


def build_engine_int8(onnx_file_path, calib):
    """Build an INT8 TensorRT engine from an ONNX model file.

    Args:
        onnx_file_path: path to the ONNX model.
        calib: INT8 calibrator, assigned to ``builder.int8_calibrator``.

    Returns:
        The engine returned by ``builder.build_cuda_engine``, or None if the
        ONNX file could not be parsed.
    """
    TRT_LOGGER = trt.Logger()
    with trt.Builder(TRT_LOGGER) as builder, \
            builder.create_network() as network, \
            trt.OnnxParser(network, TRT_LOGGER) as parser:
        # max_batch_size is fixed to 1 here. calib.get_batch_size() could be
        # used instead to match the calibrator's batch size. That is not
        # required: the inference batch size is independent of the
        # calibration batch size.
        builder.max_batch_size = 1
        builder.max_workspace_size = 1 << 30  # 1 GiB workspace limit
        builder.int8_mode = True
        builder.int8_calibrator = calib
        # Do not build from a network the parser failed to populate.
        if not _parse_onnx_file(parser, onnx_file_path):
            return None
        return builder.build_cuda_engine(network)
4. 保存及载入引擎
def save_engine(engine, engine_dest_path):
    """Serialize a built TensorRT engine and write the bytes to a file.

    Args:
        engine: the engine to serialize (anything with ``serialize()``).
        engine_dest_path: path of the file to create or overwrite.
    """
    # Serialize before opening, so a failing engine does not truncate the file.
    serialized = engine.serialize()
    with open(engine_dest_path, 'wb') as out_file:
        out_file.write(serialized)


def load_engine(engine_path):
    """Deserialize a TensorRT engine from a file written by ``save_engine``.

    Args:
        engine_path: path of the serialized engine file.

    Returns:
        The result of ``trt.Runtime.deserialize_cuda_engine`` for the file's bytes.
    """
    logger = trt.Logger(trt.Logger.WARNING)  # switch to trt.Logger.INFO for more detail
    with open(engine_path, 'rb') as engine_file:
        serialized_engine = engine_file.read()
    with trt.Runtime(logger) as runtime:
        return runtime.deserialize_cuda_engine(serialized_engine)
5. 分配缓冲区
class HostDeviceMem(object):
    """Pair of a host buffer and the device buffer it mirrors."""

    def __init__(self, host_mem, device_mem):
        self.host = host_mem      # host-side array (page-locked when built by allocate_buffers)
        self.device = device_mem  # device-side allocation of the same byte size

    def __str__(self):
        return "\n".join(["Host:", str(self.host), "Device:", str(self.device)])

    def __repr__(self):
        return str(self)


def allocate_buffers(engine):
    """Allocate host and device buffers for every binding of a TensorRT engine.

    Each binding gets a page-locked host array with
    ``volume(binding shape) * engine.max_batch_size`` elements of the
    binding's dtype, plus a device allocation of the same byte size.

    Returns:
        A tuple ``(inputs, outputs, bindings, stream)``:
        - ``inputs``: HostDeviceMem for each input binding.
        - ``outputs``: HostDeviceMem for each output binding.
        - ``bindings``: device pointers, as ints, in binding iteration order.
        - ``stream``: a newly created CUDA stream.
    """
    inputs, outputs, bindings = [], [], []
    stream = cuda.Stream()
    for binding in engine:
        np_dtype = trt.nptype(engine.get_binding_dtype(binding))
        element_count = trt.volume(engine.get_binding_shape(binding)) * engine.max_batch_size
        host_buf = cuda.pagelocked_empty(element_count, np_dtype)
        device_buf = cuda.mem_alloc(host_buf.nbytes)
        # Every binding, input or output, needs its device pointer listed.
        bindings.append(int(device_buf))
        target = inputs if engine.binding_is_input(binding) else outputs
        target.append(HostDeviceMem(host_buf, device_buf))
    return inputs, outputs, bindings, stream
6. 前向推断
def do_inference(context, bindings, inputs, outputs, stream, batch_size=1):
    """Run one asynchronous inference pass and return the host-side outputs.

    The pass queues four kinds of work on ``stream``:
    - a copy of every input's host buffer to its device buffer;
    - the execution itself;
    - a copy of every output back to its host buffer;
    - a final wait for the stream to finish.

    Args:
        context: TensorRT execution context.
        bindings: device pointers, as ints, in binding order.
        inputs: HostDeviceMem objects for the input bindings.
        outputs: HostDeviceMem objects for the output bindings.
        stream: pycuda stream on which all copies and execution are queued.
        batch_size: batch size passed to ``execute_async``.

    Returns:
        List of the outputs' host arrays. These are the same buffers every
        call writes into, so copy them if results must outlive the next call.
    """
    # Transfer input data to the GPU.
    for inp in inputs:
        cuda.memcpy_htod_async(inp.device, inp.host, stream)
    # Run inference.
    context.execute_async(batch_size=batch_size, bindings=bindings,
                          stream_handle=stream.handle)
    # Transfer predictions back from the GPU.
    for out in outputs:
        cuda.memcpy_dtoh_async(out.host, out.device, stream)
    # Block until all queued work is done so the host buffers hold the results.
    stream.synchronize()
    return [out.host for out in outputs]
7. 校准(Calibrator)
使用TensorRT的INT8模式时,需要进行校准(calibration)。具体可参见test_onnx_int8()及calibrator.py。
8. 具体的推断代码
# Flatten the image to float32 and copy it into the first input's host buffer.
img_numpy = img.ravel().astype(np.float32)
np.copyto(inputs[0].host, img_numpy)
output = do_inference(context, bindings=bindings, inputs=inputs, outputs=outputs, stream=stream)
# Iterate over all outputs when the network has several. (10,) is a 1-D shape
# tuple; presumably 10 = number of MNIST classes in this example.
output = [np.reshape(stage_i, (10,)) for stage_i in output]
9. 代码分析
程序中主要包括下面6个函数。
test_pytorch()           # test the PyTorch model
export_onnx()            # export the PyTorch model to an ONNX model
test_onnx_fp32()         # test the TensorRT FP32 model (includes code to save the engine)
test_onnx_fp32_engine()  # test the TensorRT FP32 engine
test_onnx_int8()         # test the TensorRT INT8 model (includes code to save the engine)
test_onnx_int8_engine()  # test the TensorRT INT8 engine
10. 说明
第9节中的部分函数,开头有这样一句:
# Workaround: if the results come out wrong, add this line at the start of the
# function. The loaded object is discarded, so only the call's side effects can
# matter here. The root cause has not been found.
torch.load('mnist_cnn_3.pth')
这是因为有时不加这句话直接运行代码,结果会完全不正确;加上这句话之后,结果就正确了。
具体原因尚未找到,先记录在这里。
posted on 2019-08-10 16:48 darkknightzh 阅读(15455) 评论(2) 编辑 收藏 举报