pytorch模型save() load() load_state_dict()
torch.save()
''' 模型保存: 1,保存整个网络模型,网络结构+权重参数 torch.save(model,'net.pth') 2,只保存模型的权重 torch.save(model.state_dict(),'net_params.pth') 参数(速度快,占内存少) 3,保存加载自定义模型 checkpoint={'modle':ClassNet(), 网络结构 'model_state_dict':model.state_dict(), 模型的权重参数 'optimize_state_dict':optimizer.state_dict(), 优化器参数 'epoch':epoch 其他信息:有时我们需要保存一些其他的信息,比如epoch, batch_size等超参数 } torch.save(checkpoint,'checkpoint.pkl') '''
torch.load()
# 从文件加载用torch.save()保存的对象 model_file_path = 'models/001-resnet18-2c-acc 94.80 97.48.pth' # 模型权重文件 checkpoint = torch.load(model_file_path)
torch.load_state_dict()
model = resnet18() # 加载模型 # torch.load_state_dict()函数就是用于将预训练的参数权重加载到新的模型之中 model.load_state_dict(checkpoint)