1 保存和加载整个模型
torch.save(model_object, 'model.pth')
model = torch.load('model.pth')
2 仅保存和加载模型参数
torch.save(model_obj.state_dict(), 'params.pth')
model_obj.load_state_dict(torch.load('params.pth'))
3 选择保存网络中的一部分参数或者额外保存其余的参数
torch.save({'state_dict': net.state_dict(), 'linear1':net.linear1.state_dict(),
'optimizer': optimizer.state_dict(),'num_epoch':num_epochs },
'detail.pth')
model = torch.load('detail.pth')
net = DNN(num_input,num_hidden1,num_hidden2,num_output)
net.load_state_dict(model['state_dict'])
参考:
[日常] PyTorch 预训练模型,保存,读取和更新模型参数以及多 GPU 训练模型
因上求缘,果上努力~~~~ 作者:别关注我了,私信我吧,转载请注明原文链接:https://www.cnblogs.com/BlairGrowing/p/15981277.html
分类:
编程问题
【推荐】国内首个AI IDE,深度理解中文开发场景,立即下载体验Trae
【推荐】编程新体验,更懂你的AI,立即体验豆包MarsCode编程助手
【推荐】抖音旗下AI助手豆包,你的智能百科全书,全免费不限次数
【推荐】轻量又高性能的 SSH 工具 IShell:AI 加持,快人一步
· 25岁的心里话
· 闲置电脑爆改个人服务器(超详细) #公网映射 #Vmware虚拟网络编辑器
· 基于 Docker 搭建 FRP 内网穿透开源项目(很简单哒)
· 零经验选手,Compose 一天开发一款小游戏!
· 一起来玩mcp_server_sqlite,让AI帮你做增删改查!!