Tensorflow的MobileNetV1参数迁移到pytorch上并保存

因为放弃tensorflow超级久了,也不想再去用它,因为明明很简单用pytorch十几行作出的代码,tensorflow的版本完全看不懂,我这个菜鸡还是老老实实刨地吧。mobilenet的代码网上一大堆,我把我写的贴出来吧,论文简单易读,连我这种英语渣渣两天就看完了。

MobileNet的代码如下。

import torch.nn as nn
import torch
class Conv_bn(nn.Module):
    """Standard 3x3 convolution followed by batch normalisation and ReLU.

    Args:
        inp: number of input channels.
        oup: number of output channels.
        stride: stride of the 3x3 convolution (padding is fixed at 1).
    """

    def __init__(self, inp, oup, stride):
        super().__init__()
        # The attribute name ``convBn`` and the position of each layer inside
        # the Sequential determine the state-dict keys (convBn.0.*, convBn.1.*),
        # so both are kept stable.
        conv = nn.Conv2d(inp, oup, kernel_size=3, stride=stride, padding=1, bias=False)
        norm = nn.BatchNorm2d(oup)
        activation = nn.ReLU(inplace=True)
        self.convBn = nn.Sequential(conv, norm, activation)

    def forward(self, x):
        """Run ``x`` through conv -> batch norm -> ReLU and return the result."""
        return self.convBn(x)

class Conv_depth(nn.Module):
    """Depthwise-separable convolution block.

    A 3x3 depthwise convolution (one filter group per input channel) with
    batch norm and ReLU, followed by a 1x1 pointwise convolution with batch
    norm and ReLU.

    Args:
        inp: number of input channels.
        oup: number of output channels produced by the pointwise convolution.
        stride: stride of the depthwise convolution.
    """

    def __init__(self, inp, oup, stride):
        super().__init__()
        # Layer positions inside the Sequential (0, 1, 3, 4 carry parameters)
        # define the state-dict keys, so the order below must not change.
        depthwise_stage = [
            nn.Conv2d(inp, inp, kernel_size=3, stride=stride, padding=1,
                      groups=inp, bias=False),
            nn.BatchNorm2d(inp),
            nn.ReLU(inplace=True),
        ]
        pointwise_stage = [
            nn.Conv2d(inp, oup, kernel_size=1, stride=1, padding=0, bias=False),
            nn.BatchNorm2d(oup),
            nn.ReLU(inplace=True),
        ]
        self.convDepthwise = nn.Sequential(*depthwise_stage, *pointwise_stage)

    def forward(self, x):
        """Apply the depthwise stage and then the pointwise stage to ``x``."""
        return self.convDepthwise(x)


class MobileNet(nn.Module):
    """MobileNetV1-style classifier built from ``Conv_bn`` and ``Conv_depth``.

    The feature extractor ``self.mobelnet`` is one ``Conv_bn`` stem, thirteen
    ``Conv_depth`` blocks and a 7x7 average pool; ``self.fc`` maps the 1024
    pooled features to 1000 outputs.
    """

    def __init__(self):
        super().__init__()
        # (input channels, output channels, stride) of each Conv_depth block,
        # listed in the order they are applied.
        depthwise_cfg = [
            (32, 64, 1),
            (64, 128, 2),
            (128, 128, 1),
            (128, 256, 2),
            (256, 256, 1),
            (256, 512, 2),
            (512, 512, 1),
            (512, 512, 1),
            (512, 512, 1),
            (512, 512, 1),
            (512, 512, 1),
            (512, 1024, 2),
            (1024, 1024, 1),
        ]
        stages = [Conv_bn(3, 32, 2)]
        for c_in, c_out, stride in depthwise_cfg:
            stages.append(Conv_depth(c_in, c_out, stride))
        stages.append(nn.AvgPool2d(7))
        # Children are named '0'..'14' by position; the conversion script
        # relies on these names ('mobelnet.<index>....').
        self.mobelnet = nn.Sequential(*stages)

        self.fc = nn.Linear(1024, 1000)

    # Forward pass of the network
    def forward(self, x):
        """Return the 1000-way output of the network for input batch ``x``.

        NOTE(review): ``view(-1, 1024)`` presumes the pooled feature map is
        1x1 per sample (as for 224x224 inputs) -- confirm before feeding
        other input sizes.
        """
        features = self.mobelnet(x)
        flattened = features.view(-1, 1024)
        return self.fc(flattened)

妈呀,简单吧,但是你不知道tensorflow的版本有多长啊。

然后转参数把我难住了,没做过,参考了 https://www.jianshu.com/p/0a61caeb693b 这位同学的MobileNetV3版本的改法,但是我真的不懂他那个字典怎么定义的,我每次model.层名 就开始给我出红杠杠,报错,我估计可能是他把层都封装成了对象吧,如果有懂的同学希望能给我讲讲哈。我贴我自己的代码吧。

import json
import tensorflow as tf
import os
from MobileNet.mobilenet_v1 import MobileNet
import numpy as np
import torch
import os
# Reduce TensorFlow's native log output (level '2' presumably hides INFO and
# WARNING messages -- confirm against the TensorFlow documentation).
# NOTE(review): this is set after `import tensorflow` above; confirm the
# setting still takes effect when applied this late.
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'
# Location of the TensorFlow checkpoint (mobilenet_v1_1.0_224.ckpt) that
# load_parameter() reads the source weights from.
CHECKPOINT_PATH='/Users/wenyu/Desktop/TorchProject/MobileNet/mobilenet_v1_1.0_224/mobilenet_v1_1.0_224.ckpt'

# write the json file
def new_dict(checkpoint_path,json_path):
    """Dump the variable names stored in a TF checkpoint to a JSON file.

    The JSON object maps every variable name in the checkpoint to ``1``,
    skipping names that contain ``/ExponentialMovingAverage`` or ``/RMSProp``.
    Nothing is written when ``json_path`` already exists.

    Args:
        checkpoint_path: path of the TensorFlow checkpoint to inspect.
        json_path: destination of the JSON listing.
    """
    reader = tf.compat.v1.train.NewCheckpointReader(checkpoint_path)
    weights_shape = reader.get_variable_to_shape_map()

    # Debug aid: show the shape of one known layer. Guarded so checkpoints
    # that lack this variable no longer abort with a KeyError.
    probe = 'MobilenetV1/Conv2d_9_pointwise/BatchNorm/moving_mean'
    if probe in weights_shape:
        print('the layer', weights_shape[probe])

    if os.path.exists(json_path):
        print('the json file has been write!')
        return

    # Variables whose names carry these markers are left out of the listing.
    skipped_markers = ('/ExponentialMovingAverage', '/RMSProp')
    weights_small = {
        name: 1
        for name in weights_shape
        if not any(marker in name for marker in skipped_markers)
    }
    with open(json_path, 'w') as writer:
        json.dump(weights_small, fp=writer, sort_keys=True)

# get convBn_dict
def get_convbn_convert_dict(layer_num):
    """Map the PyTorch ``Conv_bn`` parameter names of one layer to TF names.

    Args:
        layer_num: index of the layer, used both as the position inside
            ``mobelnet`` on the PyTorch side and as the ``Conv2d_<n>`` suffix
            on the TensorFlow side.

    Returns:
        dict whose keys are PyTorch state-dict names and whose values are the
        TensorFlow checkpoint variable names they are loaded from.
    """
    torch_prefix = 'mobelnet.{}.convBn'.format(layer_num)
    tf_prefix = 'MobilenetV1/Conv2d_{}'.format(layer_num)
    return {
        torch_prefix + '.0.weight': tf_prefix + '/weights',
        # PyTorch BatchNorm ``weight`` is the scale (TF ``gamma``) and ``bias``
        # is the shift (TF ``beta``); the two were previously swapped.
        torch_prefix + '.1.weight': tf_prefix + '/BatchNorm/gamma',
        torch_prefix + '.1.bias': tf_prefix + '/BatchNorm/beta',
        torch_prefix + '.1.running_mean': tf_prefix + '/BatchNorm/moving_mean',
        torch_prefix + '.1.running_var': tf_prefix + '/BatchNorm/moving_variance',
    }

# get depthWise_dict
def get_dpwise_convert_dict(layer_num):
    """Map the PyTorch ``Conv_depth`` parameter names of one layer to TF names.

    Inside ``convDepthwise`` index 0/1 are the depthwise conv and its batch
    norm, and index 3/4 are the pointwise conv and its batch norm.

    Args:
        layer_num: index of the block, used both as the position inside
            ``mobelnet`` on the PyTorch side and as the ``Conv2d_<n>_...``
            infix on the TensorFlow side.

    Returns:
        dict whose keys are PyTorch state-dict names and whose values are the
        TensorFlow checkpoint variable names they are loaded from.
    """
    torch_prefix = 'mobelnet.{}.convDepthwise'.format(layer_num)
    tf_depthwise = 'MobilenetV1/Conv2d_{}_depthwise'.format(layer_num)
    tf_pointwise = 'MobilenetV1/Conv2d_{}_pointwise'.format(layer_num)
    # PyTorch BatchNorm ``weight`` is the scale (TF ``gamma``) and ``bias`` is
    # the shift (TF ``beta``); the two were previously swapped for both norms.
    return {
        torch_prefix + '.0.weight': tf_depthwise + '/depthwise_weights',
        torch_prefix + '.1.weight': tf_depthwise + '/BatchNorm/gamma',
        torch_prefix + '.1.bias': tf_depthwise + '/BatchNorm/beta',
        torch_prefix + '.1.running_mean': tf_depthwise + '/BatchNorm/moving_mean',
        torch_prefix + '.1.running_var': tf_depthwise + '/BatchNorm/moving_variance',
        torch_prefix + '.3.weight': tf_pointwise + '/weights',
        torch_prefix + '.4.weight': tf_pointwise + '/BatchNorm/gamma',
        torch_prefix + '.4.bias': tf_pointwise + '/BatchNorm/beta',
        torch_prefix + '.4.running_mean': tf_pointwise + '/BatchNorm/moving_mean',
        torch_prefix + '.4.running_var': tf_pointwise + '/BatchNorm/moving_variance',
    }

# get conversion_dict
def get_model_dict(layers_num):
    """Build the full PyTorch-name -> TF-name conversion table.

    Layer 0 uses the ``Conv_bn`` mapping; layers ``1 .. layers_num - 1`` use
    the depthwise-separable mapping.

    Args:
        layers_num: total number of convolutional layers to cover (the stem
            plus the depthwise-separable blocks).

    Returns:
        dict combining the per-layer mappings, in layer order.
    """
    conversion_table = dict(get_convbn_convert_dict(0))
    for layer_idx in range(1, layers_num):
        conversion_table.update(get_dpwise_convert_dict(layer_idx))
    return conversion_table
def write_json(conversion_table,json_path):
    """Write ``conversion_table`` to ``json_path`` as key-sorted JSON.

    An existing file is never overwritten; a notice is printed instead.

    Args:
        conversion_table: JSON-serialisable mapping to store.
        json_path: destination file path.
    """
    if os.path.exists(json_path):
        print('the conversion table has been wirten!')
        return
    with open(json_path, 'w') as out_file:
        json.dump(conversion_table, out_file, sort_keys=True)

def _to_torch_layout(ckpt_name, value):
    """Reorder a TF convolution kernel so its axes match the PyTorch weight."""
    if 'Conv2d' in ckpt_name and 'weights' in ckpt_name:
        # assumes TF stores kernels as (kH, kW, in, out) and depthwise kernels
        # as (kH, kW, channels, multiplier) -- TODO confirm against TF docs
        if 'depthwise' in ckpt_name:
            # Same result as transposing by (3, 2, 0, 1) and then swapping the
            # first two axes, done in a single step.
            return np.transpose(value, (2, 3, 0, 1))
        return np.transpose(value, (3, 2, 0, 1))
    # Everything else (e.g. 1-D BatchNorm vectors) is used as stored.
    return value


def load_parameter(conversion_table, checkpoint_path=None,
                   save_path='/Users/wenyu/Desktop/TorchProject/MobileNet/tf_to_torch.pth'):
    """Copy TF checkpoint variables into a ``MobileNet`` state dict and save it.

    Entries of the state dict that are not named in ``conversion_table`` keep
    the values of the freshly constructed ``MobileNet()``.

    Args:
        conversion_table: mapping of PyTorch state-dict key -> TensorFlow
            checkpoint variable name.
        checkpoint_path: checkpoint to read; defaults to ``CHECKPOINT_PATH``.
        save_path: where the resulting state dict is written with
            ``torch.save``.

    Returns:
        The updated state dict.

    Raises:
        KeyError: if a PyTorch key in the table is not part of the model.
        ValueError: if a converted array does not have the shape the model
            expects for that key.
    """
    if checkpoint_path is None:
        checkpoint_path = CHECKPOINT_PATH

    module = MobileNet()
    original_model_dict = module.state_dict()
    # Open the checkpoint once and reuse the reader for every variable
    # instead of re-opening the checkpoint per lookup.
    reader = tf.compat.v1.train.NewCheckpointReader(checkpoint_path)

    for pytorch_dict_key, ckpt_name in conversion_table.items():
        expected_shape = tuple(original_model_dict[pytorch_dict_key].shape)
        converted = _to_torch_layout(ckpt_name, reader.get_tensor(ckpt_name))
        if tuple(converted.shape) != expected_shape:
            raise ValueError(
                'shape mismatch for {} <- {}: expected {}, got {}'.format(
                    pytorch_dict_key, ckpt_name, expected_shape, tuple(converted.shape)))
        # Transposed arrays are views; make them contiguous before wrapping.
        original_model_dict[pytorch_dict_key] = torch.from_numpy(
            np.ascontiguousarray(converted))

    torch.save(original_model_dict, save_path)
    return original_model_dict

if __name__ == '__main__':
    # 14 layers: the Conv_bn stem (index 0) plus 13 depthwise-separable blocks.
    converted_state = load_parameter(get_model_dict(14))
    # Sanity check: print the shape of the first depthwise kernel.
    print(converted_state['mobelnet.1.convDepthwise.0.weight'].shape)

其中核心就在最后两个函数,可能代码看起来很简单,但是我想了好久要怎么做,第一次做很不熟练,但是通过这次巩固了很多numpy,tensor还有字典的基本知识,很充实。有问题可以在博客下面留言。

posted @ 2020-05-05 22:48  daremosiranaihana  阅读(653)  评论(2)  编辑  收藏  举报