Tensorflow的MobileNetV1参数迁移到pytorch上并保存
因为放弃 TensorFlow 已经很久了,也不想再去用它:明明用 PyTorch 十几行就能写出来的代码,TensorFlow 的版本我却完全看不懂,我这个菜鸡还是老老实实刨地吧。MobileNet 的代码网上一大堆,我把我自己写的贴出来吧。论文简单易读,连我这种英语渣渣两天也看完了。
MobileNet 的代码如下。
import torch.nn as nn
import torch


class Conv_bn(nn.Module):
    """Standard 3x3 convolution followed by BatchNorm and ReLU.

    Args:
        inp: number of input channels.
        oup: number of output channels.
        stride: stride of the 3x3 convolution.
    """

    def __init__(self, inp, oup, stride):
        super(Conv_bn, self).__init__()
        # The attribute name `convBn` is part of the state_dict keys; do not rename.
        self.convBn = nn.Sequential(
            nn.Conv2d(inp, oup, 3, stride, 1, bias=False),
            nn.BatchNorm2d(oup),
            nn.ReLU(inplace=True),
        )

    def forward(self, x):
        return self.convBn(x)


class Conv_depth(nn.Module):
    """Depthwise-separable convolution block.

    A 3x3 depthwise convolution (groups == input channels) followed by a
    1x1 pointwise convolution, each with BatchNorm and ReLU.

    Args:
        inp: number of input channels.
        oup: number of output channels of the pointwise convolution.
        stride: stride of the depthwise convolution.
    """

    def __init__(self, inp, oup, stride):
        super(Conv_depth, self).__init__()
        # The attribute name `convDepthwise` is part of the state_dict keys.
        self.convDepthwise = nn.Sequential(
            # depthwise 3x3: one filter per input channel
            nn.Conv2d(inp, inp, 3, stride, 1, groups=inp, bias=False),
            nn.BatchNorm2d(inp),
            nn.ReLU(inplace=True),
            # pointwise 1x1: mixes channels
            nn.Conv2d(inp, oup, 1, 1, 0, bias=False),
            nn.BatchNorm2d(oup),
            nn.ReLU(inplace=True),
        )

    def forward(self, x):
        return self.convDepthwise(x)


class MobileNet(nn.Module):
    """MobileNetV1 backbone with a 1000-way linear classifier.

    NOTE(review): BatchNorm uses PyTorch's default epsilon here; if weights
    are imported from another framework, confirm that the epsilon used at
    training time matches, otherwise outputs will differ slightly.
    """

    def __init__(self):
        super(MobileNet, self).__init__()
        # The (misspelled) attribute name `mobelnet` is kept on purpose:
        # it is the prefix of every state_dict key of the feature extractor.
        self.mobelnet = nn.Sequential(
            Conv_bn(3, 32, 2),
            Conv_depth(32, 64, 1),
            Conv_depth(64, 128, 2),
            Conv_depth(128, 128, 1),
            Conv_depth(128, 256, 2),
            Conv_depth(256, 256, 1),
            Conv_depth(256, 512, 2),
            Conv_depth(512, 512, 1),
            Conv_depth(512, 512, 1),
            Conv_depth(512, 512, 1),
            Conv_depth(512, 512, 1),
            Conv_depth(512, 512, 1),
            Conv_depth(512, 1024, 2),
            Conv_depth(1024, 1024, 1),
            # Global average pooling. For a 224x224 input the feature map is
            # 7x7, so this equals the original AvgPool2d(7), but it also works
            # for other input resolutions. It has no parameters, so the
            # state_dict is unchanged.
            nn.AdaptiveAvgPool2d(1),
        )
        self.fc = nn.Linear(1024, 1000)

    def forward(self, x):
        """Run the forward pass and return the (batch, 1000) logits."""
        x = self.mobelnet(x)
        # Flatten per sample; keeps the batch dimension explicit instead of
        # inferring it with view(-1, 1024).
        x = x.view(x.size(0), -1)
        x = self.fc(x)
        return x
妈呀,简单吧,但是你不知道tensorflow的版本有多长啊。
然后转参数这一步把我难住了,以前没做过。我参考了 https://www.jianshu.com/p/0a61caeb693b 这位同学的 MobileNetV3 版本的改法,但是我真的没看懂他那个字典是怎么定义的:我每次写 model.层名 就开始出红杠杠、报错,我估计可能是他把层都封装成了对象吧,如果有懂的同学希望能给我讲讲哈。下面贴我自己的代码。
import json
import os

import numpy as np
import torch

# Set before TensorFlow is first imported so the C++ log level applies.
# TensorFlow itself is imported lazily inside the functions that read the
# checkpoint, which guarantees this ordering.
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'

CHECKPOINT_PATH = '/Users/wenyu/Desktop/TorchProject/MobileNet/mobilenet_v1_1.0_224/mobilenet_v1_1.0_224.ckpt'
TORCH_SAVE_PATH = '/Users/wenyu/Desktop/TorchProject/MobileNet/tf_to_torch.pth'


def new_dict(checkpoint_path, json_path):
    """Dump the checkpoint's variable names to a JSON file.

    Optimizer slots (names containing "/RMSProp") and moving-average shadow
    variables (names containing "/ExponentialMovingAverage") are left out.
    Nothing is written when ``json_path`` already exists.

    Args:
        checkpoint_path: path prefix of the TensorFlow checkpoint.
        json_path: destination JSON file.
    """
    import tensorflow as tf

    reader = tf.compat.v1.train.NewCheckpointReader(checkpoint_path)
    weights_shape = reader.get_variable_to_shape_map()
    print('the layer', weights_shape['MobilenetV1/Conv2d_9_pointwise/BatchNorm/moving_mean'])

    if os.path.exists(json_path):
        print('the json file has been write!')
        return

    weights_small = {
        name: 1
        for name in weights_shape
        if '/ExponentialMovingAverage' not in name and '/RMSProp' not in name
    }
    with open(json_path, 'w') as writer:
        json.dump(weights_small, fp=writer, sort_keys=True)


def _batchnorm_entries(torch_prefix, tf_scope):
    """Map one PyTorch BatchNorm2d's tensors to the TF BatchNorm variables.

    PyTorch's ``weight`` is the learnable scale (gamma) and ``bias`` is the
    learnable shift (beta).
    """
    return {
        torch_prefix + '.weight': tf_scope + '/BatchNorm/gamma',
        torch_prefix + '.bias': tf_scope + '/BatchNorm/beta',
        torch_prefix + '.running_mean': tf_scope + '/BatchNorm/moving_mean',
        torch_prefix + '.running_var': tf_scope + '/BatchNorm/moving_variance',
    }


def get_convbn_convert_dict(layer_num):
    """Return the {pytorch_key: tf_variable} map for a Conv_bn layer.

    Args:
        layer_num: index of the layer inside ``MobileNet.mobelnet``; the same
            index is used in the TF scope name ``Conv2d_<layer_num>``.
    """
    torch_prefix = 'mobelnet.' + str(layer_num) + '.convBn'
    tf_scope = 'MobilenetV1/Conv2d_' + str(layer_num)

    convert_dict = {torch_prefix + '.0.weight': tf_scope + '/weights'}
    convert_dict.update(_batchnorm_entries(torch_prefix + '.1', tf_scope))
    return convert_dict


def get_dpwise_convert_dict(layer_num):
    """Return the {pytorch_key: tf_variable} map for a Conv_depth layer.

    Covers the depthwise conv + BN (sub-modules 0 and 1) and the pointwise
    conv + BN (sub-modules 3 and 4).

    Args:
        layer_num: index of the layer inside ``MobileNet.mobelnet``.
    """
    torch_prefix = 'mobelnet.' + str(layer_num) + '.convDepthwise'
    depthwise_scope = 'MobilenetV1/Conv2d_' + str(layer_num) + '_depthwise'
    pointwise_scope = 'MobilenetV1/Conv2d_' + str(layer_num) + '_pointwise'

    convert_dict = {torch_prefix + '.0.weight': depthwise_scope + '/depthwise_weights'}
    convert_dict.update(_batchnorm_entries(torch_prefix + '.1', depthwise_scope))
    convert_dict[torch_prefix + '.3.weight'] = pointwise_scope + '/weights'
    convert_dict.update(_batchnorm_entries(torch_prefix + '.4', pointwise_scope))
    return convert_dict


def get_model_dict(layers_num):
    """Build the full conversion table for the feature extractor.

    Layer 0 is the plain Conv_bn stem; layers 1 .. layers_num - 1 are
    depthwise-separable blocks.

    NOTE: the final ``fc`` classifier is not part of the table, so it keeps
    its PyTorch initialisation after conversion.

    Args:
        layers_num: total number of conv layers (stem included).

    Returns:
        dict mapping PyTorch state_dict keys to TF checkpoint variable names.
    """
    conversion_table = dict(get_convbn_convert_dict(0))
    for layer_idx in range(1, layers_num):
        conversion_table.update(get_dpwise_convert_dict(layer_idx))
    return conversion_table


def write_json(conversion_table, json_path):
    """Write the conversion table to ``json_path`` unless the file exists."""
    if os.path.exists(json_path):
        print('the conversion table has been wirten!')
        return
    with open(json_path, 'w') as writer:
        json.dump(conversion_table, fp=writer, sort_keys=True)


def _to_torch_layout(ckpt_name, value):
    """Reorder a TF convolution kernel into PyTorch's axis order.

    Non-convolution variables (e.g. BatchNorm vectors) are returned as-is.
    """
    if 'Conv2d' in ckpt_name and 'weights' in ckpt_name:
        # Move the last two axes to the front:
        # (H, W, in, out) -> (out, in, H, W).
        value = np.transpose(value, (3, 2, 0, 1))
        if 'depthwise' in ckpt_name:
            # Depthwise kernels come out as (multiplier, C, H, W); swap the
            # first two axes to get the (C, 1, H, W) shape a grouped Conv2d
            # expects (valid for a channel multiplier of 1).
            value = np.transpose(value, (1, 0, 2, 3))
    return value


def load_parameter(conversion_table, checkpoint_path=CHECKPOINT_PATH,
                   save_path=TORCH_SAVE_PATH):
    """Copy the TF checkpoint weights into a MobileNet state_dict and save it.

    Args:
        conversion_table: dict mapping PyTorch state_dict keys to TF
            checkpoint variable names (see ``get_model_dict``).
        checkpoint_path: path prefix of the TensorFlow checkpoint.
        save_path: where the resulting state_dict is written by ``torch.save``.

    Returns:
        The populated state_dict.

    Raises:
        KeyError: a key of ``conversion_table`` is not in the model.
        ValueError: a converted tensor does not match the model's shape.
    """
    import tensorflow as tf
    from MobileNet.mobilenet_v1 import MobileNet

    state = MobileNet().state_dict()
    # Open the checkpoint once and read every variable through this reader.
    reader = tf.compat.v1.train.NewCheckpointReader(checkpoint_path)

    for torch_key, ckpt_name in conversion_table.items():
        if torch_key not in state:
            raise KeyError('unknown PyTorch parameter: ' + torch_key)
        value = _to_torch_layout(ckpt_name, reader.get_tensor(ckpt_name))
        # Transposed arrays are strided views; make them contiguous before
        # handing the buffer to torch.
        tensor = torch.from_numpy(np.ascontiguousarray(value))
        expected_shape = tuple(state[torch_key].shape)
        if tuple(tensor.shape) != expected_shape:
            raise ValueError(
                'shape mismatch for %s <- %s: expected %s, got %s'
                % (torch_key, ckpt_name, expected_shape, tuple(tensor.shape))
            )
        state[torch_key] = tensor

    torch.save(state, save_path)
    return state


if __name__ == '__main__':
    conversion_table = get_model_dict(14)
    dic_mobel = load_parameter(conversion_table)
    print(dic_mobel['mobelnet.1.convDepthwise.0.weight'].shape)
其中核心就在最后两个函数,可能代码看起来很简单,但是我想了好久要怎么做,第一次做很不熟练,但是通过这次巩固了很多numpy,tensor还有字典的基本知识,很充实。有问题可以在博客下面留言。