pytorch搭建网络,保存参数,恢复参数
这是看过莫凡python的学习笔记。
搭建网络,两种方式
(1)建立Sequential对象
import torch net = torch.nn.Sequential( torch.nn.Linear(2,10), torch.nn.ReLU(), torch.nn.Linear(10,2))
输出网络结构
Sequential( (0): Linear(in_features=2, out_features=10, bias=True) (1): ReLU() (2): Linear(in_features=10, out_features=2, bias=True) )
(2)建立网络类,继承torch.nn.module
class Net(torch.nn.Module): def __init__(self): super(Net,self).__init__() self.hidden = torch.nn.Linear(2,10) self.predict = torch.nn.Linear(10,2) def forward(self,x): x = F.relu(self.hidden(x)) x = self.predict(x) return x
输出和上面基本一样,略微不同
Net( (hidden): Linear(in_features=2, out_features=10, bias=True) (predict): Linear(in_features=10, out_features=2, bias=True) )
保存模型,两种方式
(1)保存整个网络,及网络参数
torch.save(net,'net.pkl')
(2)只保存网络参数
torch.save(net.state_dict(),'net_params.pkl')
恢复模型,两种方式
(1)加载整个网络,及参数
net2 = torch.load('net.pkl')
(2)加载参数,但需实现网络
net3 = torch.nn.Sequential( torch.nn.Linear(2,10), torch.nn.ReLU(), torch.nn.Linear(10,2)) net3.load_state_dict(torch.load('net_params.pkl'))