RS小生

导航

利用resnet预训练权重,出现“bn1.num_batches_tracked”或者“layer.0.bn1.num_batches_tracked" 的解决办法

报错的原因在于Pytorch0.4之后,在BN层后新增加了track_running_stats这个参数。

在调用预训练参数模型是,官方给定的预训练模型是在pytorch0.4之前,因此,调用预训练参数时,需要过滤掉“num_batches_tracked”。

以resnet50为例:

为了加载不同层的权重,采用两个函数,如下:load_partial_param用于加载layer1, layer2, layer3, layer4的权重权重,load_specific_param用于加载第一层的权重参数。

 

为了避免“num_batches_tracked”报错,采用下面的代码即可,更改部分为红色字体(方法简单,但可以满足要求)。

 

 def load_partial_param(self, state_dict, model_index, model_path):
     param_dict = torch.load(model_path)
     for i in state_dict:
         key = 'layer{}.'.format(model_index)+i
         if 'tracked' in key[-7:]:
             continue
         state_dict[i].copy_(param_dict[key])
     del param_dict

 

def load_specific_param(self, state_dict, param_name, model_path):
    param_dict = torch.load(model_path)
    for i in state_dict:
        key = param_name + '.' + i
        if 'num_batches_tracked' in key:
            continue
        state_dict[i].copy_(param_dict[key])
    del param_dict

 

 

 

posted on 2020-08-20 19:31  RS小生  阅读(3616)  评论(0编辑  收藏  举报