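# 19 - Neural networks: saving and loading models (神经网络-模型的加载和保存)
# Difference between model.eval() and torch.no_grad() (model.eval()和torch.no_grad()的区别):
# https://blog.csdn.net/qq_41813454/article/details/135129279
# In short: model.eval() switches layers such as Dropout/BatchNorm to inference
# behavior, while torch.no_grad() disables gradient tracking; inference code
# typically uses both.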
import torch
import torch.nn as nn
import torch.nn.functional as F
class MLP(nn.Module):
    """Two-layer perceptron: Linear -> ReLU -> Linear.

    The defaults reproduce the original 20 -> 256 -> 10 network. The layer
    attribute names (``hidden``, ``output``) are unchanged, so the state_dict
    keys (``hidden.weight``, ``hidden.bias``, ``output.weight``,
    ``output.bias``) match checkpoints saved from the original class.

    Args:
        in_features: size of the last dimension of the input.
        hidden_features: width of the hidden layer.
        out_features: size of the last dimension of the output.
    """

    def __init__(
        self,
        in_features: int = 20,
        hidden_features: int = 256,
        out_features: int = 10,
    ):
        super().__init__()
        self.hidden = nn.Linear(in_features, hidden_features)
        self.output = nn.Linear(hidden_features, out_features)

    def forward(self, x):
        """Map ``x`` of shape (..., in_features) to (..., out_features)."""
        return self.output(F.relu(self.hidden(x)))
# Build a model, run one forward pass, and persist only its parameters
# (the state_dict) rather than the whole pickled module object.
net = MLP()
net.eval()  # compare in inference mode; matters once Dropout/BatchNorm are added
x = torch.randn(1, 20)
with torch.no_grad():  # inference only: skip building the autograd graph
    y = net(x)
print(y)
torch.save(net.state_dict(), 'mlp.params')

# Recreate the same architecture, then load the saved parameters into it.
clone = MLP()
# weights_only=True limits unpickling to tensors and plain containers, which is
# all a state_dict needs (the argument exists in PyTorch >= 1.13).
clone.load_state_dict(torch.load('mlp.params', weights_only=True))
clone.eval()  # switch Dropout/BatchNorm-style layers to inference behavior
with torch.no_grad():
    y_clone = clone(x)
# Same parameters and same input give identical outputs; print a single bool
# instead of an element-wise boolean tensor.
print(torch.equal(y_clone, y))