机器学习——模型保存和加载
以PyTorch为例
1 2 3 4 5 6 7 8 9 10 11 12 | class MLP(nn.Module): def __init__( self ): super ().__init__() self .hidden = nn.Linear( 20 , 256 ) self .output = nn.Linear( 256 , 10 ) def forward( self , x): return self .output(F.relu( self .hidden(x))) net = MLP() X = torch.randn(size = ( 2 , 20 )) Y = net(X) |
接下来,我们将模型的参数存储在一个叫做“mlp.params”的文件中。
1 | torch.save(net.state_dict(), 'mlp.params' ) |
为了恢复模型,我们实例化了原始多层感知机模型的一个备份。 这里我们不需要随机初始化模型参数,而是直接读取文件中存储的参数。
1 2 3 | clone = MLP() clone.load_state_dict(torch.load( 'mlp.params' )) clone. eval () |
MLP( (hidden): Linear(in_features=20, out_features=256, bias=True) (output): Linear(in_features=256, out_features=10, bias=True) )
注意,model.eval()表示将模型切换到评估模式,模型在训练时通常会使用一些辅助技巧,比如dropout、batch normalization等,这些技巧在评估时需要被关闭。具体如下:
1. Dropout: 在训练时随机使一部分神经元失活,以防止过拟合。在评估时需要关闭dropout,使用全部的神经元进行预测。
2. Batch Normalization: 对每个batch做数据规范化,使训练高效稳定。评估时用整个训练数据集的均值和方差做规范化。
由于两个实例具有相同的模型参数,在输入相同的X
时, 两个实例的计算结果应该相同。 让我们来验证一下。
1 2 | Y_clone = clone(X) Y_clone = = Y |
tensor([[True, True, True, True, True, True, True, True, True, True],
[True, True, True, True, True, True, True, True, True, True]])
【推荐】国内首个AI IDE,深度理解中文开发场景,立即下载体验Trae
【推荐】编程新体验,更懂你的AI,立即体验豆包MarsCode编程助手
【推荐】抖音旗下AI助手豆包,你的智能百科全书,全免费不限次数
【推荐】轻量又高性能的 SSH 工具 IShell:AI 加持,快人一步
· 震惊!C++程序真的从main开始吗?99%的程序员都答错了
· 【硬核科普】Trae如何「偷看」你的代码?零基础破解AI编程运行原理
· 单元测试从入门到精通
· 上周热点回顾(3.3-3.9)
· winform 绘制太阳,地球,月球 运作规律
2022-11-03 1049 数列的片段和
2022-11-03 iomanip库中的常用函数