Pytorch加载.pth文件
1. .pth文件
(The weights of the model have been saved in a .pth file, which is nothing but a pickle file of the model’s tensor parameters.
We can load those into resnet18 using the model’s load _state_dict method.)
.pth文件报存了模型的权重,这个文件只是一个模型张量参数的pickle文件。
我们可以使用模型的load _state_dict方法将它们加载到 resnet18 中
2. 加载
2.1 如果.pth文件只保存了参数,则如下:
1 import torch 2 from torch.serialization import load 3 import torchvision.models as models 4 5 # pretrained=True使用预训练的模型 6 resnet18 = models.resnet18(pretrained=True)#创建实例,模型下载.Pth文件 7 model_path = 'D:/python_code/resnet18/resnet18-5c106cde.pth' 8 model_data = torch.load(model_path) 9 resnet18.load_state_dict(model_data) 10 print(resnet18)
输出为:
2.2 如果.pth文件保存的是整个网络结构+参数,则:
1 import torchvision.models as models 2 3 # pretrained=True就可以使用预训练的模型 4 resnet18 = models.resnet18(pretrained=True) 5 print(resnet18)
输出为:
参考:https://blog.csdn.net/u014264373/article/details/85332181
https://blog.csdn.net/u013679159/article/details/104253030