将 pb 模型参数提取并转换为 torch 模型

  1 import tensorflow as tf
  2 import onnx
  3 import onnxsim
  4 import numpy as np
  5 import torch
  6 from model.facedetector_model import mobilenetv2_yolov3
  7 
  8 #提取pb模型中的参数
  9 def extract_params_from_pb():
 10     constant_values = {}
 11     with tf.compat.v1.Session() as sess:
 12         with tf.io.gfile.GFile('model/FaceDetector.pb', 'rb') as f:
 13             graph_def = tf.compat.v1.GraphDef()
 14             graph_def.ParseFromString(f.read())
 15         sess.graph.as_default()
 16         tf.import_graph_def(graph_def, name='')
 17         # # input
 18         # input_x = sess.graph.get_tensor_by_name('input/input_data:0')
 19         # # output
 20         # output = sess.graph.get_tensor_by_name('pred_bbox/Reshape:0')
 21         # sess.run(output, feed_dict={'input/input_data:0': inputimage})
 22 
 23         constant_ops = [op for op in sess.graph.get_operations()]#[op for op in sess.graph.get_operations() if op.type == "Const"]
 24         for constant_op in constant_ops:
 25             if constant_op.op_def.name == "Const":
 26                 if "Shape" in constant_op.name or "pred" in constant_op.name:
 27                     continue
 28                 constant_values[constant_op.name] = sess.run(constant_op.outputs[0])
 29     return constant_values
 30 
 31 #过滤提取出来的params
 32 def filter_params(constant_values):
 33     total = 0
 34     prompt = []
 35     res = {}
 36     forbidden = ['shape','stack']
 37     
 38     for k,v in constant_values.items():
 39         # filtering some by checking ndim and name
 40         if v.ndim<1: continue
 41         if v.ndim==1:
 42             token = k.split(r'/')[-1]
 43             flag = False
 44             for word in forbidden:
 45                 if token.find(word)!=-1:
 46                     flag = True
 47                     break
 48             if flag:
 49                 continue
 50 
 51         shape = v.shape
 52         cnt = 1
 53         for dim in shape:
 54             cnt *= dim
 55         prompt.append('{} with shape {} has {}'.format(k, shape, cnt))
 56         res[k] = v
 57         print(prompt[-1])
 58         total += cnt
 59     prompt.append('totaling {}'.format(total))
 60     # print(prompt[-1])
 61     return res
 62 
 63 #将Tensorflow的张量转换成PyTorch的张量
 64 def trans_tensor_pb2pth(k,a):
 65  
 66     v = tf.convert_to_tensor(a).numpy()
 67     # tensorflow weights to pytorch weights
 68     if len(v.shape) == 4:
 69         if "depthwise_weights" in k:#防止深度可分离卷积
 70             return np.ascontiguousarray(v.transpose(2,3,0,1))
 71         return np.ascontiguousarray(v.transpose(3,2,0,1))
 72     elif len(v.shape) == 2:
 73         return np.ascontiguousarray(v.transpose())
 74     return v
 75 
 76 #将pb的对应params名字转换为pth对应参数名
 77 def trans_name_pb2pth(trans_weights):
 78     model_dict = {}
 79     for name,para in trans_weights.items():
 80         name = name.replace('/',".")
 81         
 82         if "MobilenetV2.Conv" in name:#处理MobilenetV2.Conv
 83             name = name.replace('weights',"0.weight")
 84             name = name.replace('BatchNorm',"1")
 85             name = name.replace('gamma',"weight")
 86             name = name.replace('beta',"bias")
 87             name = name.replace('moving_mean',"running_mean")
 88             name = name.replace('moving_variance',"running_var")
 89         elif "MobilenetV2.expanded_conv." in name:#处理MobilenetV2.expanded_conv.
 90             name = name.replace('depthwise.',"0.")
 91             name = name.replace('project',"1")
 92             name = name.replace('depthwise_weights',"0.weight")
 93             name = name.replace('weights',"0.weight")
 94             name = name.replace('BatchNorm',"1")
 95             name = name.replace('gamma',"weight")
 96             name = name.replace('beta',"bias")
 97             name = name.replace('moving_mean',"running_mean")
 98             name = name.replace('moving_variance',"running_var")
 99         elif "MobilenetV2.expanded_conv_" in name:#处理MobilenetV2.expanded_conv_*
100             name = name.replace('expand.',"0.")
101             name = name.replace('depthwise.',"1.")
102             name = name.replace('project',"2")
103             name = name.replace('depthwise_weights',"0.weight")
104             name = name.replace('weights',"0.weight")
105             name = name.replace('BatchNorm',"1")
106             name = name.replace('gamma',"weight")
107             name = name.replace('beta',"bias")
108             name = name.replace('moving_mean',"running_mean")
109             name = name.replace('moving_variance',"running_var")
110         elif "yolo-v3" in name:
111             if "bbox" in name:
112                 continue
113             name = name.replace('yolo-v3',"yolo_v3")
114             name = name.replace('weight',"0.weight")
115             name = name.replace('kernel',"weight")
116             name = name.replace('batch_normalization',"1")
117             name = name.replace('gamma',"weight")
118             name = name.replace('beta',"bias")
119             name = name.replace('moving_mean',"running_mean")
120             name = name.replace('moving_variance',"running_var")
121         print(name)
122         model_dict[name] = torch.Tensor(para)
123     return model_dict
124 
def copy_pbParams2pthParams(onnx_path="P3mNet.onnx", sim_path="P3mNet_sim.onnx",
                            device="cuda"):
    """Load the frozen TensorFlow weights into the PyTorch model and export ONNX.

    Steps:
    1. Extract the constants from the .pb.
    2. Filter them down to parameters.
    3. Convert each array to PyTorch layout.
    4. Rename the keys to PyTorch state_dict names.
    5. Load them into ``mobilenetv2_yolov3``.
    6. Export to ONNX and simplify the export with onnxsim.

    Args:
        onnx_path: file the raw ONNX export is written to.
        sim_path: file the onnxsim-simplified model is written to.
        device: torch device used for the export ("cuda" by default; pass
            "cpu" on a machine without a GPU).

    Returns:
        The PyTorch model with the converted weights loaded, in eval mode.

    Raises:
        RuntimeError: if ``load_state_dict`` (strict) finds missing or
            unexpected keys, or if onnxsim cannot validate the simplified model.
    """
    constant_values = extract_params_from_pb()
    tf_weights = filter_params(constant_values)
    trans_weights = {k: trans_tensor_pb2pth(k, v) for k, v in tf_weights.items()}

    # Create the PyTorch model and load the renamed weights.
    model = mobilenetv2_yolov3()
    model_dict = trans_name_pb2pth(trans_weights)
    model.load_state_dict(model_dict)
    model.to(device).eval()

    # NOTE(review): 1x1x224x224 dummy input kept from the original script —
    # confirm it matches the detector's expected input shape.
    dummy_input = torch.rand(1, 1, 224, 224, device=device)
    torch.onnx.export(model, dummy_input, onnx_path, verbose=True, opset_version=11)

    print("====> Simplifying...")
    model_opt, check = onnxsim.simplify(onnx_path)
    if not check:
        raise RuntimeError(
            "onnxsim could not validate the simplified model for {}".format(onnx_path))
    onnx.save(model_opt, sim_path)
    print("onnx model simplify Ok!")
    return model


if __name__ == "__main__":
    copy_pbParams2pthParams()

 

posted @ 2022-10-11 10:06  鲍曼小学生  阅读(321)  评论(1)  编辑  收藏  举报