PyTorch保存和加载模型

保存和加载模型

在PyTorch中使用torch.save来保存模型的结构和参数,有两种保存方式:

1
2
3
4
5
# 方式一:保存模型的结构信息和参数信息
torch.save(model, './model.pth')
 
# 方式二:仅保存模型的参数信息
torch.save(model.state_dict(), './model_state.pth')

相应的,有两种加载模型的方式:

1
2
3
4
5
# 方式一:加载完整的模型结构和参数信息,在网络较大时加载时间比较长,同时存储空间也比较大
model1= torch.load('model.pth')  
 
# 方式二:需先搭建网络模型model2,然后通过下面的语句加载参数
model2.load_state_dic(torch.load('model_state.pth'))

注:用以上的方法保存模型时,可能会遇到UserWarning: Couldn't retrieve source code for container of type Net. It won't be checked for correctness upon loading."type " + obj.__name__ + ". It won't be checked ",可参考这篇知乎文章解决这类警告。

示例

例子来自莫烦Python

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
import torch
import matplotlib.pyplot as plt
 
# fake data
x = torch.unsqueeze(torch.linspace(-1, 1, 100), dim=1# x data (tensor), shape=(100, 1)
y = x.pow(2) + 0.2*torch.rand(x.size())  # noisy y data (tensor), shape=(100, 1)
 
 
def save():
    # save net1
    net1 = torch.nn.Sequential(
        torch.nn.Linear(1, 10),
        torch.nn.ReLU(),
        torch.nn.Linear(10, 1)
    )
    optimizer = torch.optim.SGD(net1.parameters(), lr=0.3)
    loss_func = torch.nn.MSELoss()
 
    for t in range(100):
        prediction = net1(x)
        loss = loss_func(prediction, y)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
 
    # plot result
    plt.figure(1, figsize=(10, 3))
    plt.subplot(131)
    plt.title('Net1')
    plt.scatter(x.data.numpy(), y.data.numpy())
    plt.plot(x.data.numpy(), prediction.data.numpy(), 'r-', lw=5)
 
    # 2 ways to save the net
    torch.save(net1, 'net.pkl'# save entire net
    torch.save(net1.state_dict(), 'net_params.pkl')   # save only the parameters
 
 
def restore_net():
    # restore entire net1 to net2
    net2 = torch.load('net.pkl')
    prediction = net2(x)
 
    # plot result
    plt.subplot(132)
    plt.title('Net2')
    plt.scatter(x.data.numpy(), y.data.numpy())
    plt.plot(x.data.numpy(), prediction.data.numpy(), 'r-', lw=5)
 
 
def restore_params():
    # restore only the parameters in net1 to net3
    net3 = torch.nn.Sequential(
        torch.nn.Linear(1, 10),
        torch.nn.ReLU(),
        torch.nn.Linear(10, 1)
    )
 
    # copy net1's parameters into net3
    net3.load_state_dict(torch.load('net_params.pkl'))
    prediction = net3(x)
 
    # plot result
    plt.subplot(133)
    plt.title('Net3')
    plt.scatter(x.data.numpy(), y.data.numpy())
    plt.plot(x.data.numpy(), prediction.data.numpy(), 'r-', lw=5)
    plt.show()
 
# save net1
save()
 
# restore entire net (may slow)
restore_net()
 
# restore only the net parameters
restore_params()

运行结果:

posted @   Picassooo  阅读(462)  评论(0编辑  收藏  举报
编辑推荐:
· 基于Microsoft.Extensions.AI核心库实现RAG应用
· Linux系列:如何用heaptrack跟踪.NET程序的非托管内存泄露
· 开发者必知的日志记录最佳实践
· SQL Server 2025 AI相关能力初探
· Linux系列:如何用 C#调用 C方法造成内存泄露
阅读排行:
· Manus爆火,是硬核还是营销?
· 终于写完轮子一部分:tcp代理 了,记录一下
· 别再用vector<bool>了!Google高级工程师:这可能是STL最大的设计失误
· 震惊!C++程序真的从main开始吗?99%的程序员都答错了
· 单元测试从入门到精通
点击右上角即可分享
微信分享提示