PyTorch-模型保存与加载

保存: 

model = LinearRegression()
# ......各种操作
model.eval()
#训练完成,保存状态字典到linear.pkl
torch.save(model.state_dict(), './linear.pkl')

加载:

model = LinearRegression()
model.load_state_dict(torch.load('linear.pth'))
#...各种使用,比如预测...
x_test=np.arrar([..............])
x_test = torch.from_numpy(x_test)
predict_y = model(Variable(x_test))

 

posted @ 2019-03-06 16:16  jj千寻  阅读(137)  评论(0编辑  收藏  举报