import torch
import torchvision.models as models

保存和加载模型权重

pytorch模型在内部状态字典(叫做state_dict)中保存了学习参数,这些通过torch.save方法持久化。

model = models.vgg16(weights='IMAGENET1K_V1')
torch.save(model.state_dict(), 'model_weights.pth')

为了加载模型权重,首先需要创建同样模型的实例。然后使用load_state_dict()加载参数

model = models.vgg16() # we do not specify ``weights``, i.e. create untrained model
model.load_state_dict(torch.load('model_weights.pth'))
model.eval()

be sure to call model.eval() method before inferencing to set the dropout and batch normalization layers to evaluation mode. Failing to do this will yield inconsistent inference results.

保存和加载有形状的模型

有时需要保存类的结构和模型,这样就可以传递model给save方法。

torch.save(model, 'model.pth')

这种保存方式,对应的加载方法。

model = torch.load('model.pth')

This approach uses Python pickle module when serializing the model, thus it relies on the actual class definition to be available when loading the model.

 

 posted on 2024-03-24 20:09  会飞的金鱼  阅读(10)  评论(0)    收藏  举报