PyTorch保存和加载模型
保存和加载模型
在PyTorch中使用torch.save来保存模型的结构和参数,有两种保存方式:
1 2 3 4 5 | # 方式一:保存模型的结构信息和参数信息 torch.save(model, './model.pth' ) # 方式二:仅保存模型的参数信息 torch.save(model.state_dict(), './model_state.pth' ) |
相应的,有两种加载模型的方式:
1 2 3 4 5 | # 方式一:加载完整的模型结构和参数信息,在网络较大时加载时间比较长,同时存储空间也比较大 model1 = torch.load( 'model.pth' ) # 方式二:需先搭建网络模型model2,然后通过下面的语句加载参数 model2.load_state_dic(torch.load( 'model_state.pth' )) |
注:用以上的方法保存模型时,可能会遇到UserWarning: Couldn't retrieve source code for container of type Net. It won't be checked for correctness upon loading."type " + obj.__name__ + ". It won't be checked ",可参考这篇知乎文章解决这类警告。
示例
例子来自莫烦Python
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 | import torch import matplotlib.pyplot as plt # fake data x = torch.unsqueeze(torch.linspace( - 1 , 1 , 100 ), dim = 1 ) # x data (tensor), shape=(100, 1) y = x. pow ( 2 ) + 0.2 * torch.rand(x.size()) # noisy y data (tensor), shape=(100, 1) def save(): # save net1 net1 = torch.nn.Sequential( torch.nn.Linear( 1 , 10 ), torch.nn.ReLU(), torch.nn.Linear( 10 , 1 ) ) optimizer = torch.optim.SGD(net1.parameters(), lr = 0.3 ) loss_func = torch.nn.MSELoss() for t in range ( 100 ): prediction = net1(x) loss = loss_func(prediction, y) optimizer.zero_grad() loss.backward() optimizer.step() # plot result plt.figure( 1 , figsize = ( 10 , 3 )) plt.subplot( 131 ) plt.title( 'Net1' ) plt.scatter(x.data.numpy(), y.data.numpy()) plt.plot(x.data.numpy(), prediction.data.numpy(), 'r-' , lw = 5 ) # 2 ways to save the net torch.save(net1, 'net.pkl' ) # save entire net torch.save(net1.state_dict(), 'net_params.pkl' ) # save only the parameters def restore_net(): # restore entire net1 to net2 net2 = torch.load( 'net.pkl' ) prediction = net2(x) # plot result plt.subplot( 132 ) plt.title( 'Net2' ) plt.scatter(x.data.numpy(), y.data.numpy()) plt.plot(x.data.numpy(), prediction.data.numpy(), 'r-' , lw = 5 ) def restore_params(): # restore only the parameters in net1 to net3 net3 = torch.nn.Sequential( torch.nn.Linear( 1 , 10 ), torch.nn.ReLU(), torch.nn.Linear( 10 , 1 ) ) # copy net1's parameters into net3 net3.load_state_dict(torch.load( 'net_params.pkl' )) prediction = net3(x) # plot result plt.subplot( 133 ) plt.title( 'Net3' ) plt.scatter(x.data.numpy(), y.data.numpy()) plt.plot(x.data.numpy(), prediction.data.numpy(), 'r-' , lw = 5 ) plt.show() # save net1 save() # restore entire net (may slow) restore_net() # restore only the net parameters restore_params() |
运行结果:
【推荐】国内首个AI IDE,深度理解中文开发场景,立即下载体验Trae
【推荐】编程新体验,更懂你的AI,立即体验豆包MarsCode编程助手
【推荐】抖音旗下AI助手豆包,你的智能百科全书,全免费不限次数
【推荐】轻量又高性能的 SSH 工具 IShell:AI 加持,快人一步
· 基于Microsoft.Extensions.AI核心库实现RAG应用
· Linux系列:如何用heaptrack跟踪.NET程序的非托管内存泄露
· 开发者必知的日志记录最佳实践
· SQL Server 2025 AI相关能力初探
· Linux系列:如何用 C#调用 C方法造成内存泄露
· Manus爆火,是硬核还是营销?
· 终于写完轮子一部分:tcp代理 了,记录一下
· 别再用vector<bool>了!Google高级工程师:这可能是STL最大的设计失误
· 震惊!C++程序真的从main开始吗?99%的程序员都答错了
· 单元测试从入门到精通