Pytorch 模型参数保存 + 如何查看模型参数
每次机器模型训练完成后,都直接退出了。
没有仔细的研究模型中各个参数到底是怎么样的
直到前几天看到大神将10层CNN每一步都展示出来的Github, 惊为天人那https://poloclub.github.io/cnn-explainer/
于是我也想看看,首先就是将模型中的参数保存下来
官网推荐了两种方法
1. 只保存模型参数
保存:
torch.save(the_model.state_dict(), PATH)
重新加载:由于只保存了参数,重新加载时,需要创造一个新的模型框架来装参数
restore_model = TheModelClass(*args, **kwargs)
restore_model.load_state_dict(torch.load(PATH))
2. 保存整个模型
保存:
torch.save(the_model, PATH)
重新加载:保存了整个模型,不需要创造新模型
restore_model = torch.load(PATH)
最后,查看模型参数
restore_model.state_dict()