pytorch模型save() load() load_state_dict()

torch.save()
复制代码
'''
模型保存:
1,保存整个网络模型,网络结构+权重参数 
torch.save(model,'net.pth')   

2,只保存模型的权重
torch.save(model.state_dict(),'net_params.pth')  参数(速度快,占内存少)

3,保存加载自定义模型
checkpoint={'modle':ClassNet(),  网络结构
             'model_state_dict':model.state_dict(),         模型的权重参数
             'optimize_state_dict':optimizer.state_dict(),  优化器参数
             'epoch':epoch     其他信息:有时我们需要保存一些其他的信息,比如epoch, batch_size等超参数
             }
torch.save(checkpoint,'checkpoint.pkl')
'''
复制代码

 torch.load()

# 从文件加载用torch.save()保存的对象
model_file_path = 'models/001-resnet18-2c-acc 94.80 97.48.pth'  # 模型权重文件
checkpoint = torch.load(model_file_path)

 torch.load_state_dict()

model = resnet18()    # 加载模型
# torch.load_state_dict()函数就是用于将预训练的参数权重加载到新的模型之中
model.load_state_dict(checkpoint) 

 

posted @   cheng4632  阅读(253)  评论(0编辑  收藏  举报
相关博文:
阅读排行:
· 地球OL攻略 —— 某应届生求职总结
· 周边上新:园子的第一款马克杯温暖上架
· Open-Sora 2.0 重磅开源!
· 提示词工程——AI应用必不可少的技术
· .NET周刊【3月第1期 2025-03-02】
点击右上角即可分享
微信分享提示