模型超参数基本都没改,测试时加载模型报模型结构不匹配,设置模糊加载模型即:model.load_state_dict(torch.load(model_path), strict=Fasle),但效果出奇的差

原因

多卡训练;单卡模糊加载进行测试。
训练时,通过torch.nn.DataParallel(self.model)进行多卡并行训练;测试时,用单卡模糊加载保存的模型权重,很多模型参数都没有加载成功,自然会导致测试效果很差。

解决方法

测试时,使用多卡加载模型时,删掉'module.'前缀;或者用单卡加载模型进行测试。

# 删掉'module.'前缀
model_cascade1.load_state_dict(get_loaded_dict(weight_c1), strict=True)
def get_loaded_dict(weight_path):
    state_dict = torch.load(weight_path)
        
    # 检查是否有 'module.' 前缀
    has_module_prefix = any(key.startswith('module.') for key in state_dict.keys())
    if has_module_prefix:
        print("Loaded weight was from multi-GPU run. Removing 'module.' prefixes.")
        new_state_dict = {}
        for key, value in state_dict.items():
            if key.startswith('module.'):
                new_key = key[len('module.'):]
                new_state_dict[new_key] = value
            else:
                new_state_dict[key] = value
                
        return new_state_dict
    else:
        print("Loaded weight was not from multi-GPU run. No action needed.")
        return state_dict

解决效果

改动前:

改动后:

收获

我之前一直以为strict=Fasle对模型效果影响不大,这次总算知道影响有多大了。

posted @ 2023-08-15 17:28  Kurie  阅读(162)  评论(0编辑  收藏  举报