模型超参数基本都没改,测试时加载模型报模型结构不匹配,设置模糊加载模型即:model.load_state_dict(torch.load(model_path), strict=Fasle),但效果出奇的差
原因
多卡训练;单卡模糊加载进行测试。
训练时,通过torch.nn.DataParallel(self.model)进行多卡并行训练;测试时,用单卡模糊加载保存的模型权重,很多模型参数都没有加载成功,自然会导致测试效果很差。
解决方法
测试时,使用多卡加载模型时,删掉'module.'前缀;或者用单卡加载模型进行测试。
# 删掉'module.'前缀
model_cascade1.load_state_dict(get_loaded_dict(weight_c1), strict=True)
def get_loaded_dict(weight_path):
state_dict = torch.load(weight_path)
# 检查是否有 'module.' 前缀
has_module_prefix = any(key.startswith('module.') for key in state_dict.keys())
if has_module_prefix:
print("Loaded weight was from multi-GPU run. Removing 'module.' prefixes.")
new_state_dict = {}
for key, value in state_dict.items():
if key.startswith('module.'):
new_key = key[len('module.'):]
new_state_dict[new_key] = value
else:
new_state_dict[key] = value
return new_state_dict
else:
print("Loaded weight was not from multi-GPU run. No action needed.")
return state_dict
解决效果
改动前:
改动后:
收获
我之前一直以为strict=Fasle对模型效果影响不大,这次总算知道影响有多大了。
curie.