PyTorch中保存和加载模型
一般保存和加载模型的后缀名都是.ph
或者.pth
保存与加载整个模型(不推荐)
该方法会保存整个模型的网络架构以及当前训练的参数(w、b等)
保存模型
torch.save(model, PATH)
加载模型
model = torch.load(PATH)
举例:
torch.save(model_one, './my_model.pt')
model_two = torch.load('./my_model.pt')
仅仅保存模型的参数(推荐)
保存模型
torch.save(model.state_dict(), PATH)
加载模型
model = TheModelClass(*args, **kwargs) # 初始化模型时要用原来保存该参数的模型类来初始化
model.load_state_dict(torch.load(PATH))
举例:
torch.save(model_one.state_dict(), './my_model.pt') # 保存模型参数
model_two = Net(1, 10, 1) # 初始化新模型
model_two.load_state_dict(torch.load(PATH)) # 加载模型参数