保存提取神经网络

  一般来说,神经网络是用来训练数据,做实验等等,但每次关闭神经网络或者关机后,训练好的结果等都自动清除,为了解决这一问题,需要对神经网络训练后的结果进行保存、提取。

一、保存神经网络

  搭建神经网络、训练数据、优化等操作都需要在save()函数中进行,代码如下:

def save():
  #
快速搭建   net1 = torch.nn.Sequential(
    torch.nn.Linear(
2, 10),     torch.nn.ReLU(),     torch.nn.Linear(10, 2),   )   optimizer = torch.optim.SGD(net1.parameters(), lr=0.5)   loss_func = torch.nn.MSELoss()   # 训练   for t in range(100):     prediction = net1(x)     loss = loss_func(prediction, y)     optimizer.zero_grad()     loss.backward()     optimizer.step()

  接下来有两种方法保存:

torch.save(net1, 'net.pkl')  # 保存整个网络
torch.save(net1.state_dict(), 'net_params.pkl')   # 只保存网络中的参数 

 

二、提取神经网络

  2.1 提取整个神经网络

def restore_net():
    # restore entire net1 to net2
    net2 = torch.load('net.pkl')

  2.2 只提取神经网络中的参数(速度更快)

  此时,不能直接提取,如果只提取参数,那么在提取之前,需要构建一个和net1完全一样的神经网络了net3,然后把net1的参数复制到net3中

def restore_params():
    # 新建 net3
    net3 = torch.nn.Sequential(
        torch.nn.Linear(1, 10),
        torch.nn.ReLU(),
        torch.nn.Linear(10, 1)
    )

    # 将保存的参数复制到 net3
    net3.load_state_dict(torch.load('net_params.pkl'))

 

posted @ 2021-11-26 13:44  Sunshine_y  阅读(127)  评论(0编辑  收藏  举报