利用resnet预训练权重,出现“bn1.num_batches_tracked”或者“layer.0.bn1.num_batches_tracked" 的解决办法
报错的原因在于Pytorch0.4之后,在BN层后新增加了track_running_stats这个参数。
在调用预训练参数模型是,官方给定的预训练模型是在pytorch0.4之前,因此,调用预训练参数时,需要过滤掉“num_batches_tracked”。
以resnet50为例:
为了加载不同层的权重,采用两个函数,如下:load_partial_param用于加载layer1, layer2, layer3, layer4的权重权重,load_specific_param用于加载第一层的权重参数。
为了避免“num_batches_tracked”报错,采用下面的代码即可,更改部分为红色字体(方法简单,但可以满足要求)。
def load_partial_param(self, state_dict, model_index, model_path): param_dict = torch.load(model_path) for i in state_dict: key = 'layer{}.'.format(model_index)+i if 'tracked' in key[-7:]: continue state_dict[i].copy_(param_dict[key]) del param_dict
def load_specific_param(self, state_dict, param_name, model_path): param_dict = torch.load(model_path) for i in state_dict: key = param_name + '.' + i if 'num_batches_tracked' in key: continue state_dict[i].copy_(param_dict[key]) del param_dict