Pytorch模型的保存和加载
1、直接保存加载模型
(1)保存和加载整个模型
# 保存模型 torch.save(model, 'model.pth\pkl\pt') #一般形式torch.save(net, PATH) # 加载模型 model = torch.load('model.pth\pkl\pt') #一般形式为model_dict=torch.load(PATH)
(2)仅保存加载模型参数(推荐使用,模型需要提前手动构建)
# 保存模型参数 torch.save(model.state_dict(), 'model.pth\pkl\pt') #一般形式为torch.save(net.state_dict(),PATH) # 加载模型参数 model.load_state_dict(torch.load('model.pth\pkl\pt') #一般形式为model_dict=model.load_state_dict(torch.load(PATH))
state_dict() 是一个Python字典,将每一层映射成它的参数张量。注意只有带有可学习参数的层(卷积层、全连接层等),以及注册的缓存(batchnorm的运行平均值)在state_dict 中才有记录。state_dict同样包含优化器对象,存储了优化器的状态,所使用到的超参数。
2. 保存加载用于推理的常规Checkpoint/或继续训练
相当于保存一个模型文件
if (epoch+1) % checkpoint_interval == 0: checkpoint = {"model_state_dict": net.state_dict(), "optimizer_state_dict": optimizer.state_dict(), "epoch": epoch} path_checkpoint = "./checkpoint_{}_epoch.pkl".format(epoch) torch.save(checkpoint, path_checkpoint) #或者 #保存 torch.save({ 'epoch': epoch, 'model_state_dict': model.state_dict(), 'optimizer_state_dict': optimizer.state_dict(), 'loss': loss, ... }, PATH) #加载 model = TheModelClass(*args, **kwargs) optimizer = TheOptimizerClass(*args, **kwargs) checkpoint = torch.load(PATH) model.load_state_dict(checkpoint['model_state_dict']) optimizer.load_state_dict(checkpoint['optimizer_state_dict']) epoch = checkpoint['epoch'] loss = checkpoint['loss'] model.eval() # - 或者 - model.train()
在保存用于推理或者继续训练的常规检查点的时候,除了模型的state_dict之外,还必须保存其他参数。保存优化器的state_dict也非常重要,因为它包含了模型在训练时候优化器的缓存和参数。除此之外,还可以保存停止训练时epoch数,最新的模型损失,额外的torch.nn.Embedding层等。
要保存多个组件,则将它们放到一个字典中,然后使用torch.save()序列化这个字典。一般来说,使用.tar文件格式来保存这些检查点。
加载各个组件,首先初始化模型和优化器,然后使用torch.load()加载保存的字典,然后可以直接查询字典中的值来获取保存的组件。
同样,评估模型的时候一定不要忘了调用model.eval()。
【推荐】国内首个AI IDE,深度理解中文开发场景,立即下载体验Trae
【推荐】编程新体验,更懂你的AI,立即体验豆包MarsCode编程助手
【推荐】抖音旗下AI助手豆包,你的智能百科全书,全免费不限次数
【推荐】轻量又高性能的 SSH 工具 IShell:AI 加持,快人一步
· Manus重磅发布:全球首款通用AI代理技术深度解析与实战指南
· 被坑几百块钱后,我竟然真的恢复了删除的微信聊天记录!
· 没有Manus邀请码?试试免邀请码的MGX或者开源的OpenManus吧
· 【自荐】一款简洁、开源的在线白板工具 Drawnix
· 园子的第一款AI主题卫衣上架——"HELLO! HOW CAN I ASSIST YOU TODAY