[pytorch]模型参数保存与加载
最简单的情况
模型保存:
torch.save(model.state_dict(), PATH)
模型加载:
model.load_state_dict(torch.load(PATH))
此时保存的是一个字典,key为model中的weight或bias名,如"linear1.weight"或“linear2.bias”
有时我们使用了优化器
optimizer = torch.optim.Adam(model.parameters(), lr=LR)
我们在保存参数时需要同时保存优化器中的参数:
save_state = {'net':model.state_dict(), 'optimizer':optimizer.state_dict()}
torch.save(save_state, PATH)
在加载时,
model=MyModel()
model.load_state_dict(torch.load("PATH")['net'])
optimizer = torch.optim.Adam(model.parameters(), lr=LR)
optimizer.load_state_dict(torch.load("lab3_lstmtest_0614.pth")['optimizer'])
这样即保存和加载了模型和优化器参数,继续上一次训练。