pytorch模型save() load() load_state_dict()
torch.save()
''' 模型保存: 1,保存整个网络模型,网络结构+权重参数 torch.save(model,'net.pth') 2,只保存模型的权重 torch.save(model.state_dict(),'net_params.pth') 参数(速度快,占内存少) 3,保存加载自定义模型 checkpoint={'modle':ClassNet(), 网络结构 'model_state_dict':model.state_dict(), 模型的权重参数 'optimize_state_dict':optimizer.state_dict(), 优化器参数 'epoch':epoch 其他信息:有时我们需要保存一些其他的信息,比如epoch, batch_size等超参数 } torch.save(checkpoint,'checkpoint.pkl') '''
torch.load()
# 从文件加载用torch.save()保存的对象 model_file_path = 'models/001-resnet18-2c-acc 94.80 97.48.pth' # 模型权重文件 checkpoint = torch.load(model_file_path)
torch.load_state_dict()
model = resnet18() # 加载模型 # torch.load_state_dict()函数就是用于将预训练的参数权重加载到新的模型之中 model.load_state_dict(checkpoint)
【推荐】国内首个AI IDE,深度理解中文开发场景,立即下载体验Trae
【推荐】编程新体验,更懂你的AI,立即体验豆包MarsCode编程助手
【推荐】抖音旗下AI助手豆包,你的智能百科全书,全免费不限次数
【推荐】轻量又高性能的 SSH 工具 IShell:AI 加持,快人一步
· 地球OL攻略 —— 某应届生求职总结
· 周边上新:园子的第一款马克杯温暖上架
· Open-Sora 2.0 重磅开源!
· 提示词工程——AI应用必不可少的技术
· .NET周刊【3月第1期 2025-03-02】