import torch import torchvision.models as models
保存和加载模型权重
pytorch模型在内部状态字典(叫做state_dict)中保存了学习参数,这些通过torch.save方法持久化。
model = models.vgg16(weights='IMAGENET1K_V1') torch.save(model.state_dict(), 'model_weights.pth')
为了加载模型权重,首先需要创建同样模型的实例。然后使用load_state_dict()加载参数
model = models.vgg16() # we do not specify ``weights``, i.e. create untrained model model.load_state_dict(torch.load('model_weights.pth')) model.eval()
be sure to call model.eval()
method before inferencing to set the dropout and batch normalization layers to evaluation mode. Failing to do this will yield inconsistent inference results.
保存和加载有形状的模型
有时需要保存类的结构和模型,这样就可以传递model给save方法。
torch.save(model, 'model.pth')
这种保存方式,对应的加载方法。
model = torch.load('model.pth')
This approach uses Python pickle module when serializing the model, thus it relies on the actual class definition to be available when loading the model.