保存和读取神经网络

保存和读取神经网络

神经网络被训练到一个状态以后,我们希望能够把这个状态保存下来供下次使用。

以下代码使用两种方式来保存和读取神经网络:

import numpy as np
import torch
from torch.autograd import Variable

torch.manual_seed(2)  # make result reproducible

x = torch.unsqueeze(torch.linspace(-1, 1, 100), dim=1)  # unsqueeze:enlarge the dimension for figure
y = x.pow(2) + 0.2 * torch.rand(x.size())
x, y = Variable(x, requires_grad=False), Variable(y, requires_grad=False)

origin_net = torch.nn.Sequential(  # build a net
    torch.nn.Linear(1, 10),  # 1 input,10 output
    torch.nn.ReLU(),  # activation function
    torch.nn.Linear(10, 1)  # 10 input,1 input
)
optimizer = torch.optim.SGD(origin_net.parameters(), lr=0.5)
loss_func = torch.nn.MSELoss()

for t in range(1000):
    prediction = origin_net(x)
    loss = loss_func(prediction, y)
    optimizer.zero_grad()  # optimize until loss_grad to 0
    loss.backward()
    optimizer.step()

# way1:save and load the whole net
torch.save(origin_net, 'net.pkl')  # save the whole net
net1 = torch.load('net.pkl')  # load the whole net

# way2:save and load the parameters of net
torch.save(origin_net.state_dict(), 'params.pkl')  # save the parameters of net
net2 = torch.nn.Sequential(  # build an empty net
    torch.nn.Linear(1, 10),  # 1 input,10 output
    torch.nn.ReLU(),  # activation function
    torch.nn.Linear(10, 1)  # 10 input,1 input
)
net2.load_state_dict(torch.load('params.pkl'))  # load params to the empty net

input_test = torch.Tensor(np.array([0.5]))
print(net1(input_test))  # test way1
print(net2(input_test))  # test way2



两种方法分别保存整个神经网络与神经网络的参数。保存后读取整个网络只需对一个网络变量赋值,读取参数时需要新建一个结构相同的空白网络,再往空白网络中加载参数。

输出结果:

tensor([0.3395], grad_fn=<AddBackward0>)
tensor([0.3395], grad_fn=<AddBackward0>)

两种保存方法加载出来后,用一样的输入测试,输出结果一致。

posted on 2021-10-12 20:39  菜小疯  阅读(159)  评论(0编辑  收藏  举报