resnet18 下载与保存，转换为 ONNX 模型，导出 .wts 格式的权重文件
1. Download ResNet-18 and save it to the 'resnet18.pth' file:
import torch
from torch import nn
from torch.nn import functional as F
import torchvision


def main():
    """Download ImageNet-pretrained ResNet-18, smoke-test it, and save it.

    The whole module (structure + weights) is pickled to 'resnet18.pth'.
    Loading it later therefore needs torchvision importable, plus a
    ``map_location`` if it was saved on a GPU and is loaded without one.
    """
    print('cuda device count: ', torch.cuda.device_count())
    # Fall back to CPU instead of crashing on machines without CUDA.
    device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

    # torchvision >= 0.13 deprecates `pretrained=True` in favour of the
    # `weights=` enum; IMAGENET1K_V1 is the checkpoint `pretrained=True`
    # selected. Older releases have no ResNet18_Weights, so use the old flag.
    weights_enum = getattr(torchvision.models, 'ResNet18_Weights', None)
    if weights_enum is not None:
        net = torchvision.models.resnet18(weights=weights_enum.IMAGENET1K_V1)
    else:
        net = torchvision.models.resnet18(pretrained=True)
    # net.fc = nn.Linear(512, 2)  # uncomment to replace the classifier head

    net = net.to(device)
    net.eval()
    print(net)

    # Sanity check: a batch of two 3x224x224 inputs. No gradients are needed.
    tmp = torch.ones(2, 3, 224, 224).to(device)
    with torch.no_grad():
        out = net(tmp)
    print('resnet18 out:', out.shape)

    torch.save(net, "resnet18.pth")


if __name__ == '__main__':
    main()
This 'resnet18.pth' file contains both the model structure and its weights, because the whole module is saved rather than just its state_dict.
2. Load the .pth file and convert it to ONNX format:
import torch


def main():
    """Load the pickled model from 'resnet18.pth' and export it to ONNX.

    Writes 'resnet18_trtpose.onnx' to the current directory. The graph is
    traced with a single random input of shape (1, 3, 224, 224).
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # 'resnet18.pth' holds a whole pickled nn.Module, not a state dict.
    # Newer torch defaults to weights_only=True, which rejects such files, so
    # opt out explicitly. This unpickles code, so only load trusted files.
    # map_location lets a GPU-saved model load on a CPU-only machine.
    try:
        model = torch.load('resnet18.pth', map_location=device,
                           weights_only=False)
    except TypeError:
        # torch < 1.13 has no weights_only argument.
        model = torch.load('resnet18.pth', map_location=device)

    # Model and dummy input must live on the same device for tracing.
    model = model.to(device)
    # Export the inference graph: BatchNorm must use its running statistics
    # and Dropout must be off. The previous training=2 (TrainingMode.TRAINING)
    # forced training-mode ops into the exported graph.
    model.eval()

    inputs = torch.randn(1, 3, 224, 224).to(device)
    torch.onnx.export(model, inputs, 'resnet18_trtpose.onnx')


if __name__ == '__main__':
    main()
3. Load the .pth file and write the model's weights to a .wts file:
import torch
from torch import nn
import torchvision
import os
import struct
from torchsummary import summary


def main():
    """Load 'resnet18.pth', smoke-test it, and dump its weights to 'resnet18.wts'.

    Text format written:
      - line 1: number of entries in the model's state_dict;
      - then one line per entry: ``<name> <count> <hex> <hex> ...``, where
        each ``<hex>`` is one flattened element packed as a big-endian
        float32 (``struct.pack(">f", ...)``) and hex-encoded.
    """
    print('cuda device count: ', torch.cuda.device_count())
    # Fall back to CPU instead of crashing on machines without CUDA.
    device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

    # 'resnet18.pth' holds a whole pickled nn.Module. Newer torch defaults to
    # weights_only=True, which rejects such files, so opt out explicitly.
    # This unpickles code, so only load trusted files.
    try:
        net = torch.load('resnet18.pth', map_location=device,
                         weights_only=False)
    except TypeError:
        # torch < 1.13 has no weights_only argument.
        net = torch.load('resnet18.pth', map_location=device)
    net = net.to(device)
    net.eval()
    print('model: ', net)

    tmp = torch.ones(1, 3, 224, 224).to(device)
    print('input: ', tmp)
    with torch.no_grad():
        out = net(tmp)
    print('output:', out)
    # torchsummary accepts device 'cuda' or 'cpu'.
    summary(net, (3, 224, 224), device=device.type)

    state_dict = net.state_dict()
    # `with` guarantees the file is flushed and closed even on error.
    with open("resnet18.wts", 'w') as f:
        f.write("{}\n".format(len(state_dict)))
        for k, v in state_dict.items():
            print('key: ', k)
            print('value: ', v.shape)
            vr = v.reshape(-1).cpu().numpy()
            # Build the whole line once instead of two writes per element.
            hex_values = "".join(
                " " + struct.pack(">f", float(vv)).hex() for vv in vr
            )
            f.write("{} {}{}\n".format(k, len(vr), hex_values))


if __name__ == '__main__':
    main()