保存提取神经网络
一般来说,神经网络是用来训练数据,做实验等等,但每次关闭神经网络或者关机后,训练好的结果等都自动清除,为了解决这一问题,需要对神经网络训练后的结果进行保存、提取。
一、保存神经网络
搭建神经网络、训练数据、优化等操作都需要在save()函数中进行,代码如下:
def save():
#快速搭建
net1 = torch.nn.Sequential(
torch.nn.Linear(2, 10),
torch.nn.ReLU(),
torch.nn.Linear(10, 2),
)
optimizer = torch.optim.SGD(net1.parameters(), lr=0.5)
loss_func = torch.nn.MSELoss()
# 训练
for t in range(100):
prediction = net1(x)
loss = loss_func(prediction, y)
optimizer.zero_grad()
loss.backward()
optimizer.step()
接下来有两种方法保存:
torch.save(net1, 'net.pkl') # 保存整个网络
torch.save(net1.state_dict(), 'net_params.pkl') # 只保存网络中的参数
二、提取神经网络
2.1 提取整个神经网络
def restore_net():
# restore entire net1 to net2
net2 = torch.load('net.pkl')
2.2 只提取神经网络中的参数(速度更快)
此时,不能直接提取,如果只提取参数,那么在提取之前,需要构建一个和net1完全一样的神经网络了net3,然后把net1的参数复制到net3中
def restore_params():
# 新建 net3
net3 = torch.nn.Sequential(
torch.nn.Linear(1, 10),
torch.nn.ReLU(),
torch.nn.Linear(10, 1)
)
# 将保存的参数复制到 net3
net3.load_state_dict(torch.load('net_params.pkl'))