PyTorch model saving and loading
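In deep learning it is very useful to save a model and its parameters during or after training and reload them later, for example to resume an interrupted training run or to use a trained model for inference.
PyTorch provides two ways to do this. The third snippet shows how to read the parameters of a single layer.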
1. Save and restore only the parameters (recommended)
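import torch

torch.save(the_model.state_dict(), PATH)  # save only the parameters (state_dict) to PATH

the_model = TheModelClass(*args, **kwargs)  # first create the_model as an instance of TheModelClass
the_model.load_state_dict(torch.load(PATH))  # then load the parameters from PATH into it
the_model.eval()  # before inference: put dropout/batch-norm layers in evaluation mode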
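To resume training (not just run inference) you usually also need the optimizer state. Below is a minimal sketch of a training checkpoint that stores both state_dicts plus the epoch number in one file. The model `TinyNet`, the helpers `save_checkpoint`/`load_checkpoint` and the dictionary keys are hypothetical names chosen for this illustration, not part of the PyTorch API.

```python
import os
import tempfile

import torch
import torch.nn as nn


class TinyNet(nn.Module):
    """Hypothetical model used only for this illustration."""

    def __init__(self):
        super().__init__()
        self.fc = nn.Linear(4, 2)

    def forward(self, x):
        return self.fc(x)


def save_checkpoint(path, model, optimizer, epoch):
    # Store only state_dicts (tensors and plain values), not the model object itself.
    torch.save(
        {
            "epoch": epoch,
            "model_state": model.state_dict(),
            "optimizer_state": optimizer.state_dict(),
        },
        path,
    )


def load_checkpoint(path):
    # Rebuild the objects first, then fill them with the saved state.
    model = TinyNet()
    optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
    checkpoint = torch.load(path, map_location="cpu")
    model.load_state_dict(checkpoint["model_state"])
    optimizer.load_state_dict(checkpoint["optimizer_state"])
    return model, optimizer, checkpoint["epoch"]


torch.manual_seed(0)
model = TinyNet()
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)

# One training step so the parameters differ from a fresh initialization.
loss = model(torch.randn(8, 4)).pow(2).mean()
optimizer.zero_grad()
loss.backward()
optimizer.step()

path = os.path.join(tempfile.mkdtemp(), "checkpoint.pt")
save_checkpoint(path, model, optimizer, epoch=1)
restored, restored_opt, epoch = load_checkpoint(path)

# Compare every saved tensor with the restored one.
same = all(
    torch.equal(a, b)
    for a, b in zip(model.state_dict().values(), restored.state_dict().values())
)
print(epoch, same)  # 1 True
```

Because the file contains only tensors and plain Python values, it can be loaded with the default `torch.load` settings, and the model class can later be refactored as long as the parameter names stay the same.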
2. Save the entire model (structure and parameters)
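torch.save(the_model, PATH)  # pickle the whole model object to PATH
# TheModelClass must be importable when loading; PyTorch 2.6+ requires weights_only=False here
# (the argument exists since PyTorch 1.13)
the_model = torch.load(PATH, weights_only=False)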
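A runnable sketch of whole-model saving, assuming a small hypothetical `TinyNet` class and a temporary file. It shows that the loaded object reproduces the original model's outputs. Note that this relies on pickle: the class definition must be available at load time, and `weights_only=False` should only be used for files you trust.

```python
import os
import tempfile

import torch
import torch.nn as nn


class TinyNet(nn.Module):
    """Hypothetical model; its class must be importable when the file is loaded."""

    def __init__(self):
        super().__init__()
        self.fc = nn.Linear(4, 2)

    def forward(self, x):
        return self.fc(x)


model = TinyNet()
path = os.path.join(tempfile.mkdtemp(), "whole_model.pt")

# Pickles the whole module object (a reference to its class plus its parameters).
torch.save(model, path)

# Unpickling an arbitrary object needs weights_only=False on PyTorch >= 2.6.
# Only do this for trusted files: unpickling can execute arbitrary code.
loaded = torch.load(path, weights_only=False)
loaded.eval()  # switch to evaluation mode before inference

x = torch.randn(3, 4)
with torch.no_grad():
    print(torch.equal(model(x), loaded(x)))  # True
```

This is convenient for quick experiments, but the saved file is tied to the exact class and module path, which is why the state_dict approach is the recommended one.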
3. Get the parameters of a specific layer
params = model.state_dict()
for k, v in params.items():
    print(k)  # print the names of the parameters (and buffers) in the network
print(params['conv1.weight'])  # print conv1's weight
print(params['conv1.bias'])  # print conv1's bias
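Keys such as `conv1.weight` only exist if the model has a submodule named `conv1`. The sketch below uses a hypothetical `SmallCNN` with such a layer. It also shows that `state_dict()` contains buffers (for example batch-norm running statistics) in addition to learnable parameters, while `named_parameters()` lists only the learnable ones.

```python
import torch.nn as nn


class SmallCNN(nn.Module):
    """Hypothetical network with a layer named conv1, matching the keys above."""

    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(1, 6, kernel_size=3)
        self.bn1 = nn.BatchNorm2d(6)
        self.fc = nn.Linear(6, 10)

    def forward(self, x):
        x = self.bn1(self.conv1(x)).relu()
        return self.fc(x.mean(dim=(2, 3)))


model = SmallCNN()
params = model.state_dict()

# Every entry: parameter or buffer name, and its shape.
for name, tensor in params.items():
    print(name, tuple(tensor.shape))

print(params["conv1.weight"].shape)  # torch.Size([6, 1, 3, 3])
print(params["conv1.bias"].shape)    # torch.Size([6])

# named_parameters() lists only learnable parameters (no batch-norm buffers).
learnable = [name for name, _ in model.named_parameters()]
print(learnable)
```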
Reference: http://www.pytorchtutorial.com/pytorch-note5-save-and-restore-models/