Saving and Loading a Neural Network
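Once a neural network has been trained to a given state, we want to save that state so it can be reused later.
The following code saves and loads a neural network in two ways: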
import numpy as np
import torch

torch.manual_seed(2)  # make results reproducible
# unsqueeze turns the 1-D tensor of shape (100,) into shape (100, 1), the layout torch.nn.Linear expects
x = torch.unsqueeze(torch.linspace(-1, 1, 100), dim=1)
y = x.pow(2) + 0.2 * torch.rand(x.size())  # noisy quadratic targets
# Since PyTorch 0.4, tensors can be passed to a network directly, so no Variable wrapper is needed
origin_net = torch.nn.Sequential(  # build a net
    torch.nn.Linear(1, 10),  # 1 input, 10 outputs
    torch.nn.ReLU(),  # activation function
    torch.nn.Linear(10, 1)  # 10 inputs, 1 output
)
optimizer = torch.optim.SGD(origin_net.parameters(), lr=0.5)
loss_func = torch.nn.MSELoss()
for t in range(1000):
    prediction = origin_net(x)
    loss = loss_func(prediction, y)
    optimizer.zero_grad()  # clear the gradients left over from the previous step
    loss.backward()  # backpropagate to compute new gradients
    optimizer.step()  # update the parameters
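# Way 1: save and load the whole net
torch.save(origin_net, 'net.pkl')  # pickle the entire module object
# Note: PyTorch 2.6 and later default to weights_only=True in torch.load,
# so loading a whole module there needs torch.load('net.pkl', weights_only=False)
net1 = torch.load('net.pkl')  # load the whole net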
# Way 2: save and load only the parameters of the net
torch.save(origin_net.state_dict(), 'params.pkl')  # save the parameters of the net
net2 = torch.nn.Sequential(  # build an empty net with the same structure
    torch.nn.Linear(1, 10),  # 1 input, 10 outputs
    torch.nn.ReLU(),  # activation function
    torch.nn.Linear(10, 1)  # 10 inputs, 1 output
)
net2.load_state_dict(torch.load('params.pkl'))  # load the parameters into the empty net
input_test = torch.Tensor(np.array([0.5]))
print(net1(input_test)) # test way1
print(net2(input_test)) # test way2
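The two methods save the entire network and only the network's parameters, respectively. After saving, loading the whole network only requires assigning the result to a variable. Loading parameters requires first building a blank network with the same structure and then loading the parameters into it.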
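To make the "same structure" requirement concrete, here is a minimal sketch. The helper `build_net` and its `hidden` argument are hypothetical names introduced only for this illustration; they are not part of the original script. It prints what a `state_dict` contains and shows that loading it into a network with a different hidden size is rejected.

```python
import torch


def build_net(hidden):
    # Hypothetical helper: same layout as the script's network, 1 -> hidden -> 1
    return torch.nn.Sequential(
        torch.nn.Linear(1, hidden),
        torch.nn.ReLU(),
        torch.nn.Linear(hidden, 1),
    )


trained = build_net(10)

# state_dict maps parameter names to tensors; Sequential names its layers by index,
# and ReLU contributes no entries because it has no parameters
state = trained.state_dict()
shapes = {name: tuple(tensor.shape) for name, tensor in state.items()}
print(shapes)

# A blank network with the same structure accepts the parameters
same = build_net(10)
same.load_state_dict(state)

# A network with a different hidden size does not
different = build_net(20)
try:
    different.load_state_dict(state)
    mismatch_detected = False
except RuntimeError as err:
    mismatch_detected = True
    print("load failed:", str(err).splitlines()[0])
```

This is why the second method needs the network definition to be available at load time, while the saved file itself holds only named tensors.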
Output of the script that builds net1 and net2:
tensor([0.3395], grad_fn=<AddBackward0>)
tensor([0.3395], grad_fn=<AddBackward0>)
Loading the network back with either method and testing with the same input gives identical output.
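Rather than comparing printed values by eye, the agreement can also be checked in code. The sketch below is an illustration under stated assumptions, not the original script: it uses a hypothetical `build_net` helper, skips training, and saves to an in-memory `io.BytesIO` buffer instead of a file so that it leaves nothing on disk.

```python
import io

import torch


def build_net():
    # Hypothetical helper: same architecture as the script's network
    return torch.nn.Sequential(
        torch.nn.Linear(1, 10),
        torch.nn.ReLU(),
        torch.nn.Linear(10, 1),
    )


torch.manual_seed(2)
origin_net = build_net()

# Save only the parameters, here to an in-memory buffer (assumption: no file needed)
buffer = io.BytesIO()
torch.save(origin_net.state_dict(), buffer)
buffer.seek(0)

# Build a blank network, which starts with different random weights, then restore
restored_net = build_net()
restored_net.load_state_dict(torch.load(buffer))

# Compare the two networks on several inputs without tracking gradients
inputs = torch.linspace(-1, 1, 5).unsqueeze(1)
with torch.no_grad():
    outputs_match = torch.equal(origin_net(inputs), restored_net(inputs))
print(outputs_match)
```

`torch.equal` requires the two tensors to match exactly in shape and values, which is the expected result here because the restored parameters are exact copies of the originals.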