Issues encountered when converting a PyTorch .bin model file to ONNX
1 The usual approach
import os
import numpy as np
from transformers import GPT2LMHeadModel
import torch

localfile = r"C:\Users\min_ppl_model"
model = GPT2LMHeadModel.from_pretrained(localfile)

# Input shape is (1, 50): 1 is the batch size, 50 is the fixed input length
# required by the earlier quantization step.
batch_size = 1
seq = 50
example_input = torch.randn(batch_size, seq)
print(example_input)

# Output location
save_onnx_dir = r"layer20_gpt2"
os.makedirs(save_onnx_dir, exist_ok=True)
save_onnx_model = r"layer20_gpt2\layer20_gpt2.onnx"

# Export to ONNX
with torch.no_grad():
    torch.onnx.export(model,
                      example_input,
                      save_onnx_model,
                      opset_version=11,
                      input_names=["inp"],
                      output_names=["opt"])
print("-----end-------")
The critical line here is:
example_input = torch.randn(batch_size, seq, dtype=torch.float32)
Its type is:
(Pdb) type(example_input)
<class 'torch.Tensor'>
Printed, it looks like:
tensor([[-2.0319, -0.4021, 0.0092, 1.4100, -0.2214, 0.6954, 0.1764, 0.2111,
-0.4725, -0.7527, -0.0766, -1.4510, -0.2528, -1.4077, -0.9340, 0.2309,
0.5564, -0.0498, -0.7499, -1.8176, -0.8981, 1.3997, 0.2904, -0.5024,
0.8392, -0.2341, -0.5459, -1.0992, 0.0211, 0.3346, -0.5087, -0.6159,
1.3256, -0.0423, -0.0764, 1.0469, -1.7328, -1.3470, -0.3346, 0.2129,
1.1073, -0.7503, 0.3968, -0.1374, 0.6514, -0.8763, -1.1972, -1.7750,
-0.9977, -2.2836]])
The export fails with:
RuntimeError: Expected tensor for argument #1 'indices' to have one of the following scalar types: Long, Int; but got torch.FloatTensor instead (while checking arguments for embedding)
The cause: the model's first operation is an embedding lookup, and embedding indices must be integers (Long or Int), not floats.
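This can be reproduced without GPT-2 at all. A minimal sketch, assuming only torch, using a standalone nn.Embedding:

import torch
import torch.nn as nn

emb = nn.Embedding(num_embeddings=100, embedding_dim=8)

# Float indices trigger the same kind of dtype error as the GPT-2 token embedding.
try:
    emb(torch.randn(1, 5))
except RuntimeError as e:
    print("float input:", e)

# Integer indices (Long or Int) work fine.
print(emb(torch.randint(0, 100, (1, 5))).shape)   # torch.Size([1, 5, 8])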
The direct fix is to use an integer dummy input instead:
example_input = torch.randint(1,5, size=(batch_size, seq))
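Before re-running the export, a plain forward pass can sanity-check the new dummy input. This is a sketch reusing model, batch_size, and seq from the script above; the logits handling covers both newer transformers versions (ModelOutput with .logits) and older ones (tuple output):

example_input = torch.randint(1, 5, size=(batch_size, seq))
print(example_input.dtype)   # torch.int64 (Long) by default, accepted by the embedding

with torch.no_grad():
    out = model(example_input)

# Newer transformers returns a ModelOutput with .logits; older versions return a tuple.
logits = out.logits if hasattr(out, "logits") else out[0]
print(logits.shape)          # torch.Size([1, 50, vocab_size])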
2 Generate with NumPy first, then convert to torch.Tensor
example_input = torch.from_numpy(np.random.randint(low=6,high=10,size=(1,50),dtype=np.int32))
This way you can also pin the integer values to a rough range.
(Pdb) type(example_input)
<class 'torch.Tensor'>
Printed output:
tensor([[6, 9, 9, 9, 8, 6, 6, 6, 8, 8, 9, 9, 9, 7, 9, 9, 7, 8, 8, 7, 9, 9, 9, 6,
7, 6, 8, 8, 9, 9, 8, 7, 7, 8, 8, 9, 6, 9, 7, 9, 8, 6, 9, 7, 7, 7, 9, 6,
7, 8]], dtype=torch.int32)
With this input, the export runs successfully.
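As an optional extra check (not part of the original steps; it assumes the onnxruntime package is installed), the exported file can be loaded and run with the same kind of integer input:

import numpy as np
import onnxruntime as ort

sess = ort.InferenceSession(r"layer20_gpt2\layer20_gpt2.onnx",
                            providers=["CPUExecutionProvider"])
inp_name = sess.get_inputs()[0].name    # "inp", the name given via input_names at export time

# The dtype must match the dummy input used during export (int32 here).
dummy = np.random.randint(low=6, high=10, size=(1, 50), dtype=np.int32)
outputs = sess.run(None, {inp_name: dummy})
print(outputs[0].shape)                 # first output "opt": the logits, e.g. (1, 50, vocab_size)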
Summary:
1 Match the randomly generated dummy data to the dtype the model expects: if it wants integers, generate integers; if it wants floats, generate floats.
2 If you create the tensor with an explicit dtype=torch.int32, the dtype is also shown when the tensor is printed; with the default dtype it is omitted, as in the printout and the short check below.
tensor([[6, 9, 9, 9, 8, 6, 6, 6, 8, 8, 9, 9, 9, 7, 9, 9, 7, 8, 8, 7, 9, 9, 9, 6,
7, 6, 8, 8, 9, 9, 8, 7, 7, 8, 8, 9, 6, 9, 7, 9, 8, 6, 9, 7, 7, 7, 9, 6,
7, 8]], dtype=torch.int32)
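A quick standalone illustration of point 2 (only torch is needed):

import torch

a = torch.randint(1, 5, size=(1, 6))                      # default integer dtype: torch.int64
b = torch.randint(1, 5, size=(1, 6), dtype=torch.int32)   # explicit int32

print(a)   # e.g. tensor([[3, 1, 4, 2, 2, 1]])                     <- no dtype shown (default)
print(b)   # e.g. tensor([[2, 4, 1, 3, 1, 2]], dtype=torch.int32)  <- dtype shown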