机器学习——模型保存和加载
以PyTorch为例
class MLP(nn.Module): def __init__(self): super().__init__() self.hidden = nn.Linear(20, 256) self.output = nn.Linear(256, 10) def forward(self, x): return self.output(F.relu(self.hidden(x))) net = MLP() X = torch.randn(size=(2, 20)) Y = net(X)
接下来,我们将模型的参数存储在一个叫做“mlp.params”的文件中。
torch.save(net.state_dict(), 'mlp.params')
为了恢复模型,我们实例化了原始多层感知机模型的一个备份。 这里我们不需要随机初始化模型参数,而是直接读取文件中存储的参数。
clone = MLP() clone.load_state_dict(torch.load('mlp.params')) clone.eval()
MLP( (hidden): Linear(in_features=20, out_features=256, bias=True) (output): Linear(in_features=256, out_features=10, bias=True) )
注意,model.eval()表示将模型切换到评估模式,模型在训练时通常会使用一些辅助技巧,比如dropout、batch normalization等,这些技巧在评估时需要被关闭。具体如下:
1. Dropout: 在训练时随机使一部分神经元失活,以防止过拟合。在评估时需要关闭dropout,使用全部的神经元进行预测。
2. Batch Normalization: 对每个batch做数据规范化,使训练高效稳定。评估时用整个训练数据集的均值和方差做规范化。
由于两个实例具有相同的模型参数,在输入相同的X
时, 两个实例的计算结果应该相同。 让我们来验证一下。
Y_clone = clone(X) Y_clone == Y
tensor([[True, True, True, True, True, True, True, True, True, True],
[True, True, True, True, True, True, True, True, True, True]])