机器学习——模型保存和加载

以PyTorch为例

class MLP(nn.Module):
    def __init__(self):
        super().__init__()
        self.hidden = nn.Linear(20, 256)
        self.output = nn.Linear(256, 10)

    def forward(self, x):
        return self.output(F.relu(self.hidden(x)))

net = MLP()
X = torch.randn(size=(2, 20))
Y = net(X)

 接下来,我们将模型的参数存储在一个叫做“mlp.params”的文件中。

torch.save(net.state_dict(), 'mlp.params')

 

为了恢复模型,我们实例化了原始多层感知机模型的一个备份。 这里我们不需要随机初始化模型参数,而是直接读取文件中存储的参数。

clone = MLP()
clone.load_state_dict(torch.load('mlp.params'))
clone.eval()
MLP(
  (hidden): Linear(in_features=20, out_features=256, bias=True)
  (output): Linear(in_features=256, out_features=10, bias=True)
)

注意,model.eval()表示将模型切换到评估模式,模型在训练时通常会使用一些辅助技巧,比如dropout、batch normalization等,这些技巧在评估时需要被关闭。具体如下:

1. Dropout: 在训练时随机使一部分神经元失活,以防止过拟合。在评估时需要关闭dropout,使用全部的神经元进行预测。

2. Batch Normalization: 对每个batch做数据规范化,使训练高效稳定。评估时用整个训练数据集的均值和方差做规范化。

 

由于两个实例具有相同的模型参数,在输入相同的X时, 两个实例的计算结果应该相同。 让我们来验证一下。

Y_clone = clone(X)
Y_clone == Y
tensor([[True, True, True, True, True, True, True, True, True, True],
        [True, True, True, True, True, True, True, True, True, True]])

 

posted @ 2023-11-03 11:33  Yohoc  阅读(29)  评论(0编辑  收藏  举报