PyTorch——模型保存与模型加载(一)

模型保存与加载有两种方式,本文暂时只讨论模型参数方式

1> 单GPU

保存

1 torch.save(model.state_dict(), "model.pth")

加载

1 model = SimpleNet()
2 model.load_state_dict(torch.load("./model.pth"))

2> 多GPU 

保存

1 torch.save(model.module.state_dict(), "./model.pth")

加载

1 mdoel = SimpleNet()
2 model.load_state_dict(torch.load("./model.pth"))

 

posted @ 2021-03-30 10:36  一剑光寒十四州  阅读(96)  评论(0编辑  收藏  举报