保存和恢复神经网络

转自莫烦大神,转载原因是想把所有相关内容收集到自己的博客中,方便系统的学习。

两种保存方法,1是保存整个神经网络;2是只保存神经网络的所有参数。

一、保存神经网络

1保存整个神经网络。

 torch.save(net1,"net1.pkl")

net1为我想要保存的网络,net1.pkl为文件名,保存的格式只能是.pkl

2,保存神经网络参数

torch.save(net1.state_dict(),"net1_parmaer.pkl") 

二、恢复神经网络

1恢复完整神经网络(直接load())

net2=torch.load("net1.pkl")

2.从参数中恢复神经网络

需先构建与所要恢复的神经网络相同结构,再load参数。

3,完整程序如下

import torch
import matplotlib.pyplot as plt
x = torch.unsqueeze(torch.linspace(-1, 1, 100), dim=1)  # x data (tensor), shape=(100, 1)
y = x.pow(2) + 0.2*torch.rand(x.size())  # noisy y data (tensor), shape=(100, 1)

def save():
    net1 = torch.nn.Sequential(
        torch.nn.Linear(1, 10),  # 一层神经层
        torch.nn.ReLU(),  # 加激励函数,relu相当于类
        torch.nn.Linear(10, 1),
    )
    optimizer=torch.optim.SGD(net1.parameters(),lr=0.5)
    loss_func=torch.nn.MSELoss()
    for t in range(100):
        prediction=net1(x)
        loss=loss_func(prediction,y)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
    #画图
    plt.figure(1, figsize=(10, 3))
    plt.subplot(131)
    plt.title('Net1')
    plt.scatter(x.data.numpy(), y.data.numpy())   #实际数据
    plt.plot(x.data.numpy(), prediction.data.numpy(), 'r-', lw=5)     #回归曲线


    torch.save(net1,"net1.pkl")   #保存整个神经网络
    torch.save(net1.state_dict(),"net1_parmaer.pkl")   #保存神经网络中的所有参数

def restore_net():
    net2=torch.load("net1.pkl")
    prediction2=net2(x)

    plt.subplot(132)
    plt.title('Net2')
    plt.scatter(x.data.numpy(), y.data.numpy())   #实际数据
    plt.plot(x.data.numpy(), prediction2.data.numpy(), 'r-', lw=5)     #回归曲线


def restore_paramers():
    net3=torch.nn.Sequential(
        torch.nn.Linear(1, 10),  # 一层神经层
        torch.nn.ReLU(),  # 加激励函数,relu相当于类
        torch.nn.Linear(10, 1),
    )
    net3.load_state_dict(torch.load("net1_parmaer.pkl"))  #先构建网络在,再加载参数
    prediction3 = net3(x)

    plt.subplot(133)
    plt.title('Net3')
    plt.scatter(x.data.numpy(), y.data.numpy())  # 实际数据
    plt.plot(x.data.numpy(), prediction3.data.numpy(), 'r-', lw=5)  # 回归曲线
    plt.show()

save()
restore_net()

运行结果:

 

posted @ 2018-12-26 11:21  小小小小小码农  阅读(1575)  评论(0编辑  收藏  举报