UNet PyTorch 模型转 ONNX 模型完整代码
import os

import numpy as np
import torch

# Hide every CUDA device from this process so the export runs on the CPU.
# CUDA reads CUDA_VISIBLE_DEVICES (plural); the previous spelling
# "CUDA_VISIBLE_DEVICE" was silently ignored and left all GPUs visible.
os.environ["CUDA_VISIBLE_DEVICES"] = ""


def main():
    """Export the UNet checkpoint to ONNX, then validate the exported file."""
    demo = Demo(model_path="/xxx.pth.tar", output="pathto/xxx.onnx")
    demo.inference()
    # Check the file that was actually written instead of a second,
    # independently typed path that can drift out of sync.
    check_onnx(onnx_pth=demo.output_path)


def check_onnx(onnx_pth):
    """Load an ONNX file, validate it and print its graph.

    Args:
        onnx_pth: Path of the ``.onnx`` file to check.

    Raises:
        Whatever ``onnx.checker.check_model`` raises when the model is not
        well formed.
    """
    import onnx  # imported lazily: only this optional check needs it

    model = onnx.load(onnx_pth)
    # Verify that the model's IR is well formed.
    onnx.checker.check_model(model)
    # Print a human-readable representation of the graph.
    print(onnx.helper.printable_graph(model.graph))


class WrappedModel(torch.nn.Module):
    """Wrap a network so that its output is passed through a sigmoid.

    The exported ONNX graph therefore ends with the sigmoid and consumers get
    activated values rather than the wrapped model's raw outputs (presumably
    logits -- confirm against the UNet definition).
    """

    def __init__(self, model):
        super().__init__()
        self.model = model

    def forward(self, x):
        return torch.sigmoid(self.model(x))


class Demo:
    """Load a UNet checkpoint and export it (with a sigmoid head) to ONNX.

    Args:
        model_path: Path of the PyTorch checkpoint to load.
        output: Destination path of the exported ``.onnx`` file.
        in_channels: Input channels of the UNet (default 3).
        out_channels: Output channels of the UNet (default 2).
        input_size: ``(height, width)`` of the dummy input used for tracing.
        opset_version: ONNX opset version to target (default 11).
    """

    def __init__(self, model_path, output, *, in_channels=3, out_channels=2,
                 input_size=(512, 512), opset_version=11):
        self.model_path = model_path
        self.output_path = output
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.input_size = input_size
        self.opset_version = opset_version
        # Default device so init_model()/resume() also work when called
        # without init_torch_tensor() first.
        self.device = "cpu"

    def init_torch_tensor(self):
        """Select the CPU as export device and float32 as default dtype."""
        # The export deliberately runs on the CPU (CUDA devices are hidden at
        # import time). torch.set_default_tensor_type('torch.FloatTensor') is
        # deprecated since PyTorch 2.1; set_default_dtype is its replacement.
        self.device = "cpu"
        torch.set_default_dtype(torch.float32)

    def init_model(self, in_channels, out_channels):
        """Build a UNET with the given channel counts on ``self.device``."""
        # Project-local import, done lazily (like ``onnx`` in check_onnx) so
        # the rest of this script can be imported without the Unet module.
        from Unet import UNET

        model = UNET(in_channels=in_channels, out_channels=out_channels)
        return model.to(self.device)

    def resume(self, model, path):
        """Load checkpoint weights into ``model`` and wrap it with a sigmoid.

        Args:
            model: Module to load the weights into (modified in place).
            path: Checkpoint path. Either a dict holding a ``"state_dict"``
                entry (e.g. alongside ``"optimizer"``) or a bare state dict.

        Returns:
            A ``WrappedModel`` around ``model``.

        Raises:
            FileNotFoundError: if ``path`` does not exist. Previously this
                returned None, which only failed later in ``inference``.
        """
        if not os.path.exists(path):
            raise FileNotFoundError("Checkpoint not found: " + path)
        states = torch.load(path, map_location=self.device)
        state_dict = states.get("state_dict", states)
        # strict=False keeps tolerating key mismatches, but report them:
        # a silently mismatched checkpoint would export untrained weights.
        incompatible = model.load_state_dict(state_dict, strict=False)
        if incompatible.missing_keys:
            print("Warning: keys missing from checkpoint:",
                  incompatible.missing_keys)
        if incompatible.unexpected_keys:
            print("Warning: unexpected keys in checkpoint:",
                  incompatible.unexpected_keys)

        model_sig = WrappedModel(model)
        print("Resume from " + path)
        return model_sig

    def inference(self):
        """Trace the wrapped UNet and write it to ``self.output_path``."""
        self.init_torch_tensor()
        model = self.init_model(in_channels=self.in_channels,
                                out_channels=self.out_channels)
        model_sig = self.resume(model, self.model_path)
        # Export in eval mode so layers such as dropout/batch-norm use their
        # inference behaviour.
        model_sig.eval()

        # Dummy input used only to trace the graph: a random HWC uint8 image
        # scaled to [0, 1], transposed to CHW and given a batch dim
        # -> shape (1, C, H, W).
        height, width = self.input_size
        img = np.random.randint(0, 256, size=(height, width, self.in_channels),
                                dtype=np.uint8)
        img = img.astype(np.float32) / 255
        img = torch.from_numpy(img.transpose((2, 0, 1))).unsqueeze(0)

        # Dynamic axes: each key must be a name from input_names/output_names;
        # each value maps a dimension index to the symbolic name it gets in
        # the ONNX graph (a list of indices is also accepted). Batch, height
        # and width are left dynamic; presumably H/W must still be compatible
        # with the UNet's downsampling -- TODO confirm.
        dynamic_axes = {
            "input": {0: "batch_size", 2: "height", 3: "width"},
            "output": {0: "batch_size", 2: "height", 3: "width"},
        }
        with torch.no_grad():
            img = img.to(self.device)
            torch.onnx.export(
                model_sig,
                img,
                self.output_path,
                input_names=["input"],
                output_names=["output"],
                dynamic_axes=dynamic_axes,
                keep_initializers_as_inputs=False,
                export_params=True,
                verbose=True,
                opset_version=self.opset_version,
            )


if __name__ == "__main__":
    main()