PyTorch中保存和加载模型

一般保存和加载模型的后缀名都是.ph或者.pth

官方文档

保存与加载整个模型(不推荐)

该方法会保存整个模型的网络架构以及当前训练的参数(w、b等)

保存模型

torch.save(model, PATH)

加载模型

model = torch.load(PATH)

举例:

torch.save(model_one, './my_model.pt')
model_two = torch.load('./my_model.pt')

仅仅保存模型的参数(推荐)

保存模型

torch.save(model.state_dict(), PATH)

加载模型

model = TheModelClass(*args, **kwargs)  # 初始化模型时要用原来保存该参数的模型类来初始化
model.load_state_dict(torch.load(PATH))

举例:

torch.save(model_one.state_dict(), './my_model.pt')  # 保存模型参数
model_two = Net(1, 10, 1)  # 初始化新模型
model_two.load_state_dict(torch.load(PATH))  # 加载模型参数
posted @ 2020-11-29 19:22  火锅先生  阅读(253)  评论(0编辑  收藏  举报