pytorch加载预训练模型
(1) 保存和加载整个模型
1 2 3 4 | # 模型保存 torch.save(model, 'model.pth' ) # 模型加载 model = torch.load( 'model.pth' ) |
(2) 仅仅保存模型参数以及分别加载模型结构和参数
1 2 3 4 5 | # 模型参数保存 torch.save(model.state_dict(), 'model_param.pth' ) # 模型参数加载,加载预训练模型 model = ModelClass(...) model.load_state_dict(torch.load( 'model_param.pth' )) |
加载部分预训练模型
1 2 3 4 5 6 7 8 9 10 11 12 | resnet152 = models.resnet152(pretrained = True ) pretrained_dict = resnet152.state_dict() """加载torchvision中的预训练模型和参数后通过state_dict()方法提取参数 也可以直接从官方model_zoo下载: pretrained_dict = model_zoo.load_url(model_urls['resnet152'])""" model_dict = model.state_dict() # 将pretrained_dict里不属于model_dict的键剔除掉,只加载重复的网络结构的参数 pretrained_dict = {k: v for k, v in pretrained_dict.items() if k in model_dict} # 更新现有的model_dict model_dict.update(pretrained_dict) # 加载我们真正需要的state_dict,将更新好的模型加载训练 model.load_state_dict(model_dict) |
【推荐】国内首个AI IDE,深度理解中文开发场景,立即下载体验Trae
【推荐】编程新体验,更懂你的AI,立即体验豆包MarsCode编程助手
【推荐】抖音旗下AI助手豆包,你的智能百科全书,全免费不限次数
【推荐】轻量又高性能的 SSH 工具 IShell:AI 加持,快人一步
· TypeScript + Deepseek 打造卜卦网站:技术与玄学的结合
· Manus的开源复刻OpenManus初探
· AI 智能体引爆开源社区「GitHub 热点速览」
· 从HTTP原因短语缺失研究HTTP/2和HTTP/3的设计差异
· 三行代码完成国际化适配,妙~啊~
2018-03-25 tensorflow函数介绍(4)
2018-03-25 python其他篇(1)