加载训练的模型参数并继续训练
参考连接:
https://blog.csdn.net/hungryof/article/details/81364487
保存模型:
torch.save(model.state_dict(), model_path)
加载模型时一般用
model.load_state_dict(torch.load(model_path))
其中,model_path 为模型路径。
值得注意的是:torch.load
返回的是一个 OrderedDict.
但是可能这样加载模型继续训练时,会出现一些问题,故可以改为:
model.load_state_dict(torch.load(model_path), strict=False)
pytorch官网:
https://pytorch.org/tutorials/beginner/saving_loading_models.html
感觉写的蛮详细的博客: