学习笔记14:模型保存
转自:https://www.cnblogs.com/miraclepbc/p/14361926.html
保存训练过程中使得测试集上准确率最高的参数
import copy
best_model_wts = copy.deepcopy(model.state_dict())
best_acc = 0
train_loss = []
train_acc = []
test_loss = []
test_acc = []
for epoch in range(extend_epoch):
epoch_loss, epoch_acc, epoch_test_loss, epoch_test_acc = fit(epoch, model, train_dl, test_dl)
if epoch_test_acc > best_acc:
best_model_wts = copy.deepcopy(model.state_dict())
best_acc = epoch_test_acc
train_loss.append(epoch_loss)
train_acc.append(epoch_acc)
test_loss.append(epoch_test_loss)
test_acc.append(epoch_test_acc)
model.load_state_dict(best_model_wts)
保存模型
PATH = 'E:/my_model.pth'
torch.save(model.state_dict(), PATH)
重新加载模型
new_model = models.resnet101(pretrained = True)
in_f = new_model.fc.in_features
new_model.fc = nn.Linear(in_f, 4)
new_model.load_state_dict(torch.load(PATH))
测试是否加载成功
new_model.to(device)
test_correct = 0
test_total = 0
new_model.eval()
with torch.no_grad():
for x, y in test_dl:
x, y = x.to(device), y.to(device)
y_pred = new_model(x)
loss = loss_func(y_pred, y)
y_pred = torch.argmax(y_pred, dim = 1)
test_correct += (y_pred == y).sum().item()
test_total += y.size(0)
epoch_test_acc = test_correct / test_total
print(epoch_test_acc)
分类:
AI
, AI / Pytorch
标签:
Pytorch
【推荐】国内首个AI IDE,深度理解中文开发场景,立即下载体验Trae
【推荐】编程新体验,更懂你的AI,立即体验豆包MarsCode编程助手
【推荐】抖音旗下AI助手豆包,你的智能百科全书,全免费不限次数
【推荐】轻量又高性能的 SSH 工具 IShell:AI 加持,快人一步
· winform 绘制太阳,地球,月球 运作规律
· 震惊!C++程序真的从main开始吗?99%的程序员都答错了
· AI与.NET技术实操系列(五):向量存储与相似性搜索在 .NET 中的实现
· 超详细:普通电脑也行Windows部署deepseek R1训练数据并当服务器共享给他人
· 【硬核科普】Trae如何「偷看」你的代码?零基础破解AI编程运行原理
2023-06-04 Kubescape入门