pytorch学习007- -预训练中的权重加载(完全导入,部分导入)
更新
2022.04.12更新
导入权重的用法相当普遍,但是可以导入吗?导入有什么影响?
首先一定是可以导入的,但是导入之后是否有效果?那应该分以下情况讨论。
- 网络模型完全对应:这种情况可以导入,而且微调效果更好
- 网络模型不完全对应(小心这种情况)
- 只是输出层有部分变化,可以导入
- 中间层有变化,不建议导入
问题
- 预训练后的权重如何导入另一个网络模型?
- 预训练对应的网络模型A与未训练的网络模型结构B不对应?
2.1 两个网络模型A和B只有部分对应
2.2 集合关系上A属于B
2.3 集合关系上B属于A
方案
PyTorch文档
- torch.nn.modules.module.Module def load_state_dict(self,
state_dict: Dict[str, Tensor] | OrderedDict[str, Tensor],
strict: bool = …) -> None - 说明:将 state_dict 中的参数和缓冲区复制到此模块及其后代中。
- 如果 strict 为 True,则 state_dict 的键必须与此模块的torch.nn.Module.state_dict 函数返回的键完全匹配
- 参数
state_dict – 包含参数和持久缓冲区的字典。
strict – 是否严格强制:- attr:
state_dict
中的键与该模块的 :meth:~torch.nn.Module.state_dict
函数返回的键匹配。 默认值:“真”
- attr:
- 返回值:
- missing_keys 是包含缺失键的 str 列表
- unexpected_keys 是包含意外键的 str 列表
模型对应,完全导入
# demo1 完全加载权重
model = NET1()
state_dict = model.state_dict()
weights = torch.load(weights_path)['model_state_dict'] #读取预训练模型权重
model.load_state_dict(weights)
模型不完全对应
此一种情况经常出现在要修改预训练网络模型中某些层时,可能增加若干层,可能减少若干层,或上述两种情况皆有。
只有部分对应
两个模型中有部分是对应的,此种情况建议使用PyTorch中的load_state_dict所提供的参数:strict
将strict设置为False,可以在两个模型不同的情况下,仅加载相同键值部分。(保证各层的名字相同)
# demo2
model = NET2()
state_dict = model.state_dict()
weights = torch.load(weights_path)['model_state_dict'] #读取预训练模型权重
model.load_state_dict(weights, strict=False) #strict
A属于B
此种情况常见于,在网上download别人的预训练模型后,需要根据自己的任务,添加若干个层,而其他层保持不变。
# demo3
*****待测试
B属于A
此种情况常见于从网上download别人的预训练模型后,因为某些限制,需要对模型进行精简,只删除若干个层,其他层保持不变。
# demo4
*****待测试
【推荐】国内首个AI IDE,深度理解中文开发场景,立即下载体验Trae
【推荐】编程新体验,更懂你的AI,立即体验豆包MarsCode编程助手
【推荐】抖音旗下AI助手豆包,你的智能百科全书,全免费不限次数
【推荐】轻量又高性能的 SSH 工具 IShell:AI 加持,快人一步
· 开源Multi-agent AI智能体框架aevatar.ai,欢迎大家贡献代码
· Manus重磅发布:全球首款通用AI代理技术深度解析与实战指南
· 被坑几百块钱后,我竟然真的恢复了删除的微信聊天记录!
· 没有Manus邀请码?试试免邀请码的MGX或者开源的OpenManus吧
· 园子的第一款AI主题卫衣上架——"HELLO! HOW CAN I ASSIST YOU TODAY