PyTorch-模型保存与加载
保存:
model = LinearRegression()
# ......各种操作
model.eval()
#训练完成,保存状态字典到linear.pkl
torch.save(model.state_dict(), './linear.pkl')
加载:
model = LinearRegression()
model.load_state_dict(torch.load('linear.pth'))
#...各种使用,比如预测...
x_test=np.arrar([..............])
x_test = torch.from_numpy(x_test)
predict_y = model(Variable(x_test))