PyTorch模型加载与保存的最佳实践
一般来说PyTorch有两种保存和读取模型参数的方法。但这篇文章我记录了一种最佳实践,可以在加载模型时避免掉一些问题。
传统方案:
第一种方案是保存整个模型:
torch.save(model_object, 'model.pth')
第二种方法是保存模型网络参数:
torch.save(model_object.state_dict(), 'params.pth')
加载的时候分别这样加载:
model = torch.load('model.pth')
以及:
model_object.load_state_dict(torch.load('params.pth'))
改进的方案
注意到这个方案是因为模型在加载之后,loss会飙升之后再慢慢降回来。查阅有关分析之后,判定是优化器optimizer的问题。
如果模型的保存是为了恢复训练状态,那么可以考虑同时保存优化器optimizer的参数:
state = { 'epoch': epoch, 'net': model.state_dict(), 'optimizer': optimizer.state_dict(), ... } torch.save(state, filepath)
然后这样加载:
checkpoint = torch.load(model_path) model.load_state_dict(checkpoint['net']) optimizer.load_state_dict(checkpoint['optimizer']) start_epoch = checkpoint['epoch'] + 1
如果模型的保存是为了方便以后进行validation和test,可以在加载完之后制定model.eval()固定dropout和BN层。
快去成为你想要的样子!