PyTorch 介绍 | 保存和加载模型
本节我们将会看到如何保存模型状态、加载和运行模型预测
import torch import torchvision.models as models
保存和加载模型权重
PyTorch模型在一个称为 state_dict
的内部状态字典内保存了学习的参数,可以通过 torch.save
实现这一过程。
model = models.vgg16(pretrained=True) torch.save(model.state_dict(), 'model_weights.pth')
为了加载模型参数,你需要首先创建一个相同模型的实体,然后使用 load_state_dict()
加载参数。
model = models.vgg16() # we do not specify pretrained=True, i.e. do not load default weights model.load_state_dict(torch.load('model_weights.pth')) model.eval()
注意:在推理前,确保调用 model.eval()
设置dropout和batch normalization层是评估模式,否则将产生不一致的推断结果。
使用Shapes保存和加载模型
当加载模型权重时,我们需要首先初始化模型类,因为该类定义了网络结构。我们可能想将模型权重和该类的结构保存在一起,在这种情况下,可以将 model
(而不是model.state_dict()
)传入保存函数。
torch.save(model, 'model.pth')
加载
model = torch.load('model.pth')
注意:这种方法在序列化模型时使用Python pickle模块,因此,它依赖于加载模型时可用的实际类的定义。
【推荐】国内首个AI IDE,深度理解中文开发场景,立即下载体验Trae
【推荐】编程新体验,更懂你的AI,立即体验豆包MarsCode编程助手
【推荐】抖音旗下AI助手豆包,你的智能百科全书,全免费不限次数
【推荐】轻量又高性能的 SSH 工具 IShell:AI 加持,快人一步
· 从 HTTP 原因短语缺失研究 HTTP/2 和 HTTP/3 的设计差异
· AI与.NET技术实操系列:向量存储与相似性搜索在 .NET 中的实现
· 基于Microsoft.Extensions.AI核心库实现RAG应用
· Linux系列:如何用heaptrack跟踪.NET程序的非托管内存泄露
· 开发者必知的日志记录最佳实践
· winform 绘制太阳,地球,月球 运作规律
· TypeScript + Deepseek 打造卜卦网站:技术与玄学的结合
· Manus的开源复刻OpenManus初探
· 写一个简单的SQL生成工具
· AI 智能体引爆开源社区「GitHub 热点速览」