PyTorch model saving and loading

In deep learning it is very useful to save a model and its parameters during or after training and reload them later, for example to resume training or to run inference.

PyTorch provides two ways to save and load a model.

1. Save and restore only the parameters (recommended)

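import torch

torch.save(the_model.state_dict(), PATH)    # save only the parameters (the state_dict) to PATH

the_model = TheModelClass(*args, **kwargs)    # first create an instance of TheModelClass with the same architecture
the_model.load_state_dict(torch.load(PATH))    # then load the saved parameters from PATH into it
the_model.eval()    # switch to evaluation mode before inference (affects dropout and batch norm layers)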
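A minimal, self-contained sketch of method 1, assuming a small `nn.Sequential` network stands in for `TheModelClass` and a file in a temporary directory stands in for `PATH` (both are illustrative choices, not from the original). It saves only the `state_dict`, rebuilds the same architecture, loads the parameters back, and checks that the two models agree.

```python
import os
import tempfile

import torch
import torch.nn as nn


def build_model():
    # Hypothetical stand-in for TheModelClass: the architecture must be rebuilt in code before loading
    return nn.Sequential(nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, 2))


torch.manual_seed(0)
the_model = build_model()

with tempfile.TemporaryDirectory() as tmpdir:
    path = os.path.join(tmpdir, "model_params.pt")  # hypothetical file name
    torch.save(the_model.state_dict(), path)  # save the parameters only

    restored = build_model()  # fresh instance, initialised with different random weights
    restored.load_state_dict(torch.load(path))  # copy the saved parameters into it
    restored.eval()  # inference mode (matters for dropout/batch norm layers)

# Every parameter tensor of the restored model should match the original
same_params = all(
    torch.equal(a, b)
    for a, b in zip(the_model.state_dict().values(), restored.state_dict().values())
)

# With identical weights, both models give identical outputs for the same input
x = torch.randn(3, 4)
same_output = torch.equal(the_model(x), restored(x))
print(same_params, same_output)
```

Because only tensors are stored, the file does not depend on where the model class is defined, which is why this approach is usually recommended over pickling the whole model.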

2. Save the whole model (structure and parameters)

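import torch

torch.save(the_model, PATH)    # pickle the whole model object (structure and parameters) to PATH

# TheModelClass must still be importable when loading, because pickle stores only a reference to the class
the_model = torch.load(PATH, weights_only=False)    # weights_only=False is required on PyTorch 2.6+, where the default became True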
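A minimal sketch of method 2, assuming PyTorch 1.13 or newer (where `torch.load` accepts `weights_only`). It uses a built-in `nn.Sequential` so the pickled model can be unpickled without extra imports; for a custom `nn.Module` subclass, the class definition must be importable under the same module path when `torch.load` runs. Since PyTorch 2.6, `torch.load` defaults to `weights_only=True`, so loading a whole pickled model needs `weights_only=False`, which should only be used on files from a trusted source.

```python
import os
import tempfile

import torch
import torch.nn as nn

torch.manual_seed(0)
the_model = nn.Sequential(nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, 2))

with tempfile.TemporaryDirectory() as tmpdir:
    path = os.path.join(tmpdir, "whole_model.pt")  # hypothetical file name
    torch.save(the_model, path)  # pickles the module object: structure + parameters

    # weights_only=False is needed to unpickle a full nn.Module (the default is True since PyTorch 2.6);
    # unpickling can run arbitrary code, so only do this for trusted files
    loaded = torch.load(path, weights_only=False)

x = torch.randn(3, 4)
print(type(loaded).__name__)  # Sequential
print(torch.equal(the_model(x), loaded(x)))  # True
```

No model has to be constructed before loading, which is convenient, but the saved file is tied to the code layout that defined the class.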

3. Get the parameters of a specific layer

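params = model.state_dict()    # ordered dict mapping parameter names to tensors
for k, v in params.items():
    print(k)    # print the name of each parameter in the network, e.g. 'conv1.weight'
print(params['conv1.weight'])    # print conv1's weight (assumes the model has a layer named conv1)
print(params['conv1.bias'])    # print conv1's bias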
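A self-contained sketch of inspecting one layer, assuming a small hypothetical CNN (`SmallNet`) whose first convolution is named `conv1`. The keys in `state_dict()` come from the attribute names assigned in `__init__`, and the same values can also be reached directly as `model.conv1.weight`.

```python
import torch
import torch.nn as nn


class SmallNet(nn.Module):
    # Hypothetical network; the attribute name "conv1" determines the state_dict keys
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(in_channels=1, out_channels=6, kernel_size=5)
        self.fc = nn.Linear(6 * 24 * 24, 10)  # a 28x28 input shrinks to 24x24 after a 5x5 conv

    def forward(self, x):
        x = torch.relu(self.conv1(x))
        return self.fc(x.flatten(1))


model = SmallNet()
params = model.state_dict()

# Print every parameter name together with its tensor shape
for name, tensor in params.items():
    print(name, tuple(tensor.shape))
# conv1.weight (6, 1, 5, 5)
# conv1.bias (6,)
# fc.weight (10, 3456)
# fc.bias (10,)

# The state_dict entry holds the same values as the layer's Parameter
conv1_weight = params["conv1.weight"]
print(torch.equal(conv1_weight, model.conv1.weight))  # True
```

Printing shapes rather than full tensors is often enough to confirm which layer a key belongs to.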

Reference: http://www.pytorchtutorial.com/pytorch-note5-save-and-restore-models/

posted @ 2018-02-21 09:51  蠢材少年