Converting PyTorch to ONNX

Single input:

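    import torch

    # model: the trained nn.Module; onnx_save_name: output path, e.g. "model.onnx"
    model.eval()  # export in inference mode (dropout off, batch-norm uses running stats)
    input_tensor = torch.randn([1, 3, 256, 512])  # dummy input; only its shape matters for tracing
    print("Exporting to ONNX: ", onnx_save_name)
    torch.onnx.export(model, input_tensor, onnx_save_name,
                      export_params=True,   # store the trained weights in the .onnx file
                      verbose=True,         # print the exported graph
                      input_names=['label'],
                      output_names=["synthesized"],
                      opset_version=11)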

Multiple inputs (passed to torch.onnx.export as a tuple):

    import torch

    # The tuple is unpacked in the order of model.forward's arguments
    model.eval()
    input_tensor = torch.randn([1, 3, 256, 512])
    mask_tensor = torch.randn([1, 3, 256, 512])
    print("Exporting to ONNX: ", onnx_save_name)
    torch.onnx.export(model, (input_tensor, mask_tensor), onnx_save_name,
                      export_params=True,
                      verbose=True,
                      input_names=['label', 'mask'],  # one name per tuple element, same order
                      output_names=["synthesized"],
                      opset_version=11)

posted @ 2020-11-04 14:01  皮卡皮卡妞