pytorch模型save() load() load_state_dict()

torch.save()
'''
模型保存:
1,保存整个网络模型,网络结构+权重参数 
torch.save(model,'net.pth')   

2,只保存模型的权重
torch.save(model.state_dict(),'net_params.pth')  参数(速度快,占内存少)

3,保存加载自定义模型
checkpoint={'modle':ClassNet(),  网络结构
             'model_state_dict':model.state_dict(),         模型的权重参数
             'optimize_state_dict':optimizer.state_dict(),  优化器参数
             'epoch':epoch     其他信息:有时我们需要保存一些其他的信息,比如epoch, batch_size等超参数
             }
torch.save(checkpoint,'checkpoint.pkl')
'''

 torch.load()

# 从文件加载用torch.save()保存的对象
model_file_path = 'models/001-resnet18-2c-acc 94.80 97.48.pth'  # 模型权重文件
checkpoint = torch.load(model_file_path)

 torch.load_state_dict()

model = resnet18()    # 加载模型
# torch.load_state_dict()函数就是用于将预训练的参数权重加载到新的模型之中
model.load_state_dict(checkpoint) 

 

posted @ 2022-06-22 14:56  cheng4632  阅读(239)  评论(0编辑  收藏  举报