模型超参数基本都没改,测试时加载模型报模型结构不匹配,设置模糊加载模型即:model.load_state_dict(torch.load(model_path), strict=Fasle),但效果出奇的差
原因
多卡训练;单卡模糊加载进行测试。
训练时,通过torch.nn.DataParallel(self.model)进行多卡并行训练;测试时,用单卡模糊加载保存的模型权重,很多模型参数都没有加载成功,自然会导致测试效果很差。
解决方法
测试时,使用多卡加载模型时,删掉'module.'前缀;或者用单卡加载模型进行测试。
# 删掉'module.'前缀
model_cascade1.load_state_dict(get_loaded_dict(weight_c1), strict=True)
def get_loaded_dict(weight_path):
state_dict = torch.load(weight_path)
# 检查是否有 'module.' 前缀
has_module_prefix = any(key.startswith('module.') for key in state_dict.keys())
if has_module_prefix:
print("Loaded weight was from multi-GPU run. Removing 'module.' prefixes.")
new_state_dict = {}
for key, value in state_dict.items():
if key.startswith('module.'):
new_key = key[len('module.'):]
new_state_dict[new_key] = value
else:
new_state_dict[key] = value
return new_state_dict
else:
print("Loaded weight was not from multi-GPU run. No action needed.")
return state_dict
解决效果
改动前:
改动后:
收获
我之前一直以为strict=Fasle对模型效果影响不大,这次总算知道影响有多大了。
curie.
【推荐】国内首个AI IDE,深度理解中文开发场景,立即下载体验Trae
【推荐】编程新体验,更懂你的AI,立即体验豆包MarsCode编程助手
【推荐】抖音旗下AI助手豆包,你的智能百科全书,全免费不限次数
【推荐】轻量又高性能的 SSH 工具 IShell:AI 加持,快人一步
· 全程不用写代码,我用AI程序员写了一个飞机大战
· DeepSeek 开源周回顾「GitHub 热点速览」
· 记一次.NET内存居高不下排查解决与启示
· MongoDB 8.0这个新功能碉堡了,比商业数据库还牛
· .NET10 - 预览版1新功能体验(一)