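[Notes] PyTorch Quickstart: Training, Saving and Loading Models
Optimizing Model Parameters
Now that we have a model, the next step is to train, validate and test it.
Prerequisite Code
First, load the data and build the model: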
```python
import torch
from torch import nn
from torch.utils.data import DataLoader
from torchvision import datasets
from torchvision.transforms import ToTensor

# FashionMNIST: 28x28 grayscale clothing images in 10 classes
training_data = datasets.FashionMNIST(
    root="data",
    train=True,
    download=True,
    transform=ToTensor()
)

test_data = datasets.FashionMNIST(
    root="data",
    train=False,
    download=True,
    transform=ToTensor()
)

train_dataloader = DataLoader(training_data, batch_size=64)
test_dataloader = DataLoader(test_data, batch_size=64)

class NeuralNetwork(nn.Module):
    def __init__(self):
        super().__init__()
        self.flatten = nn.Flatten()
        self.linear_relu_stack = nn.Sequential(
            nn.Linear(28*28, 512),
            nn.ReLU(),
            nn.Linear(512, 512),
            nn.ReLU(),
            nn.Linear(512, 10),
        )

    def forward(self, x):
        x = self.flatten(x)
        logits = self.linear_relu_stack(x)
        return logits

model = NeuralNetwork()
```
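Hyperparameters
We define three hyperparameters:
- Epochs: the number of times to iterate over the whole dataset
- Batch size: the number of samples propagated through the network before the parameters are updated
- Learning rate: how much the parameters are adjusted at each update; a small value makes training slow, while a large value can make it unstable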
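Batch size and epochs together fix how many parameter updates a run performs, while the learning rate scales each of those updates. A small arithmetic sketch, assuming the 60,000-image FashionMNIST training split with batch_size=64 and epochs=10 (the values the full implementation ends up using); `batches_per_epoch` is a hypothetical helper:

```python
import math

def batches_per_epoch(num_samples: int, batch_size: int) -> int:
    # DataLoader keeps the final partial batch by default (drop_last=False)
    return math.ceil(num_samples / batch_size)

num_samples = 60_000   # FashionMNIST training split
batch_size = 64
epochs = 10

steps = batches_per_epoch(num_samples, batch_size)
last_batch = num_samples - (steps - 1) * batch_size  # size of the final, smaller batch
total_updates = steps * epochs                       # one optimizer.step() per batch

print(steps, last_batch, total_updates)  # 938 32 9380
```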
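Optimization Loop
Next, we optimize the model over multiple rounds; each full pass over the dataset is called an epoch.
Each epoch has two parts: the training loop, which iterates over the training set and updates the parameters, and the validation/test loop, which iterates over the test set to check whether the model is improving.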
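Loss Function
PyTorch provides the common loss functions, for example:
- nn.MSELoss (Mean Squared Error), for regression
- nn.NLLLoss (Negative Log Likelihood), for classification
- nn.CrossEntropyLoss (Cross Entropy), which combines nn.LogSoftmax and nn.NLLLoss
We use cross-entropy and pass in the raw model outputs (logits) directly; nn.CrossEntropyLoss normalizes them internally, so the model does not need to apply softmax itself.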
```python
loss_fn = nn.CrossEntropyLoss()
```
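To make the relationship between nn.CrossEntropyLoss, log-softmax and nn.NLLLoss concrete, here is a plain-Python sketch of the computation under the default settings (reduction="mean", no class weights, no label smoothing). The helper functions are hypothetical illustrations operating on lists; the real loss does the same on tensors.

```python
import math

def log_softmax(logits):
    # subtract the max before exponentiating for numerical stability
    m = max(logits)
    log_sum = m + math.log(sum(math.exp(z - m) for z in logits))
    return [z - log_sum for z in logits]

def nll(log_probs, target):
    # negative log-likelihood of the correct class
    return -log_probs[target]

def cross_entropy(batch_logits, targets):
    # log-softmax followed by NLL, averaged over the batch (reduction="mean")
    losses = [nll(log_softmax(row), t) for row, t in zip(batch_logits, targets)]
    return sum(losses) / len(losses)

logits = [[2.0, 0.5, -1.0], [0.0, 0.0, 0.0]]  # raw scores for 2 samples, 3 classes
targets = [0, 2]
print(cross_entropy(logits, targets))
```

A confident correct prediction (first row) gives a small loss, while uniform logits (second row) give log(3) regardless of the target.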
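Optimizer
Initialize the optimizer by passing it the model parameters to optimize and the learning-rate hyperparameter. Here we use stochastic gradient descent (SGD):
```python
learning_rate = 1e-3
optimizer = torch.optim.SGD(model.parameters(), lr=learning_rate)
```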
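Inside the training loop, optimization happens in three steps for every batch:
- `optimizer.zero_grad()`: reset the gradients of the model parameters; gradients accumulate by default, so they must be cleared before each backward pass
- `loss.backward()`: backpropagate the prediction loss, storing the gradient of the loss with respect to each parameter
- `optimizer.step()`: adjust the parameters using the gradients collected in the backward pass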
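These three calls can be mimicked in plain Python with a single scalar parameter, which shows why zero_grad() comes first: backward() adds to `.grad` instead of overwriting it, and step() for plain SGD (no momentum) applies p ← p − lr · grad. The classes `Param` and `SimpleSGD` and the function `backward` are hypothetical stand-ins, not PyTorch API; the toy task fits w in y = 2x with a squared-error loss.

```python
class Param:
    """Hypothetical stand-in for a tensor with requires_grad=True."""
    def __init__(self, value):
        self.value = value
        self.grad = 0.0

class SimpleSGD:
    """Mimics optimizer.zero_grad() and optimizer.step() for plain SGD."""
    def __init__(self, params, lr):
        self.params = params
        self.lr = lr

    def zero_grad(self):
        for p in self.params:
            p.grad = 0.0

    def step(self):
        for p in self.params:
            p.value -= self.lr * p.grad

def backward(w, x, y):
    # d/dw of (w*x - y)**2, ACCUMULATED into w.grad like loss.backward()
    w.grad += 2 * (w.value * x - y) * x

w = Param(0.0)
opt = SimpleSGD([w], lr=0.05)
data = [(1.0, 2.0), (2.0, 4.0), (3.0, 6.0)]  # samples of y = 2x

for epoch in range(50):
    for x, y in data:         # one "batch" = one sample here
        opt.zero_grad()       # 1. clear the gradient from the previous batch
        backward(w, x, y)     # 2. compute the new gradient
        opt.step()            # 3. move w against the gradient

print(round(w.value, 4))  # 2.0
```

Dropping the `opt.zero_grad()` line would make each step use the sum of all past gradients, which is usually not what you want.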
Full Implementation
We train in `train_loop` and evaluate on the test set in `test_loop`:
```python
import torch
from torch import nn
from torch.utils.data import DataLoader
from torchvision import datasets
from torchvision.transforms import ToTensor

# Hyperparameters
learning_rate = 1e-3
batch_size = 64
epochs = 10

training_data = datasets.FashionMNIST(
    root="data",
    train=True,
    download=True,
    transform=ToTensor()
)

test_data = datasets.FashionMNIST(
    root="data",
    train=False,
    download=True,
    transform=ToTensor()
)

train_dataloader = DataLoader(training_data, batch_size=batch_size)
test_dataloader = DataLoader(test_data, batch_size=batch_size)

class NeuralNetwork(nn.Module):
    def __init__(self):
        super().__init__()
        self.flatten = nn.Flatten()
        self.linear_relu_stack = nn.Sequential(
            nn.Linear(28*28, 512),
            nn.ReLU(),
            nn.Linear(512, 512),
            nn.ReLU(),
            nn.Linear(512, 10),
        )

    def forward(self, x):
        x = self.flatten(x)
        logits = self.linear_relu_stack(x)
        return logits

model = NeuralNetwork()

# Initialize the loss function and the optimizer
loss_fn = nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(model.parameters(), lr=learning_rate)

def train_loop(dataloader, model, loss_fn, optimizer):
    size = len(dataloader.dataset)
    model.train()  # training mode (matters for layers such as dropout/batch norm)
    for batch, (X, y) in enumerate(dataloader):
        # Compute prediction and loss
        pred = model(X)
        loss = loss_fn(pred, y)

        # Backpropagation
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        if batch % 100 == 0:
            loss, current = loss.item(), batch * len(X)
            print(f"loss: {loss:>7f} [{current:>5d}/{size:>5d}]")

def test_loop(dataloader, model, loss_fn):
    size = len(dataloader.dataset)
    num_batches = len(dataloader)
    model.eval()  # evaluation mode
    test_loss, correct = 0, 0

    # Gradients are not needed for evaluation
    with torch.no_grad():
        for X, y in dataloader:
            pred = model(X)
            test_loss += loss_fn(pred, y).item()
            correct += (pred.argmax(1) == y).type(torch.float).sum().item()

    test_loss /= num_batches
    correct /= size
    print(f"Test Error: \n Accuracy: {(100*correct):>0.1f}%, Avg loss: {test_loss:>8f} \n")

for t in range(epochs):
    print(f"Epoch {t + 1}\n-------------------------------")
    train_loop(train_dataloader, model, loss_fn, optimizer)
    test_loop(test_dataloader, model, loss_fn)
print("Done!")
```
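The metrics printed by `test_loop` can be checked with a plain-Python sketch. Note that the average loss divides by the number of batches while accuracy divides by the number of samples; since the last test batch is smaller (10,000 images at batch_size=64 leave 16 in the final batch), the reported loss is a mean of per-batch means rather than an exact per-sample mean. The function `evaluate` and its input format are hypothetical.

```python
def argmax(row):
    # index of the largest score, like pred.argmax(1) applied to one row
    return max(range(len(row)), key=lambda i: row[i])

def evaluate(batches):
    """batches: list of (logit_rows, targets, batch_mean_loss) tuples (hypothetical format)."""
    total_loss, correct, size = 0.0, 0, 0
    for logits, targets, batch_loss in batches:
        total_loss += batch_loss  # like test_loss += loss_fn(pred, y).item()
        correct += sum(argmax(row) == t for row, t in zip(logits, targets))
        size += len(targets)
    # loss is averaged over batches, accuracy over individual samples
    return correct / size, total_loss / len(batches)

batches = [
    ([[0.1, 0.9], [0.8, 0.2]], [1, 1], 0.6),  # one right, one wrong
    ([[0.3, 0.7]], [1], 0.2),                 # smaller final batch, right
]
acc, avg_loss = evaluate(batches)
print(f"Accuracy: {100 * acc:>0.1f}%, Avg loss: {avg_loss:>8f}")
```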
Saving and Loading Models
How do we save a trained model so that it can be loaded again later?
```python
import torch
import torchvision.models as models
```
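Saving and Loading Model Weights
A PyTorch model stores its learned parameters in an internal dictionary called the state_dict, returned by `model.state_dict()`. The `torch.save` method writes this dictionary to a file:
```python
# torchvision >= 0.13 uses the weights argument; older versions used pretrained=True
model = models.vgg16(weights='IMAGENET1K_V1')
torch.save(model.state_dict(), 'model_weights.pth')
```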
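To load the weights, first construct a model of the same architecture, then load the saved parameters into it:
```python
model = models.vgg16()  # no weights specified, i.e. do not load the default pretrained weights
model.load_state_dict(torch.load('model_weights.pth', weights_only=True))
model.eval()
```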
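Why a same-architecture model must exist before loading can be illustrated with a toy analogy in plain Python: a state_dict maps parameter names to values, and loading copies those values into a model whose names and sizes must match (PyTorch's `load_state_dict` is strict by default and raises a RuntimeError on mismatched keys or shapes; this sketch raises KeyError/ValueError instead). `TinyModel` and its methods are hypothetical, and `pickle` stands in for `torch.save`, which uses pickle for serialization by default.

```python
import os
import pickle
import tempfile

class TinyModel:
    """Hypothetical two-layer 'model' whose parameters are plain lists."""
    def __init__(self, hidden=4):
        self.params = {
            "layer1.weight": [0.0] * hidden,
            "layer2.weight": [0.0] * hidden,
        }

    def state_dict(self):
        # name -> parameter values, analogous to model.state_dict()
        return {k: list(v) for k, v in self.params.items()}

    def load_state_dict(self, state):
        # strict check: names and sizes must match the freshly built architecture
        if state.keys() != self.params.keys():
            raise KeyError(f"key mismatch: {sorted(state.keys() ^ self.params.keys())}")
        for k, v in state.items():
            if len(v) != len(self.params[k]):
                raise ValueError(f"size mismatch for {k}")
            self.params[k] = list(v)

trained = TinyModel()
trained.params["layer1.weight"] = [0.1, 0.2, 0.3, 0.4]

with tempfile.TemporaryDirectory() as tmp:
    path = os.path.join(tmp, "tiny_weights.pkl")
    with open(path, "wb") as f:
        pickle.dump(trained.state_dict(), f)   # like torch.save(model.state_dict(), ...)

    fresh = TinyModel()                         # build the same architecture first
    with open(path, "rb") as f:
        fresh.load_state_dict(pickle.load(f))   # then copy the weights in

print(fresh.params["layer1.weight"])  # [0.1, 0.2, 0.3, 0.4]
```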
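Note that `model.eval()` must be called before inference: it switches layers such as dropout and batch normalization to evaluation mode, and skipping it leads to inconsistent inference results.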
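What `model.eval()` changes can be seen with dropout. The plain-Python sketch below mimics inverted dropout as nn.Dropout applies it: in training mode each element is zeroed with probability p and the survivors are scaled by 1/(1-p); in evaluation mode the layer passes its input through unchanged. The `Dropout` class here is a hypothetical illustration, not torch.nn.Dropout.

```python
import random

class Dropout:
    """Hypothetical plain-Python inverted dropout; `training` mirrors module.training."""
    def __init__(self, p=0.5):
        self.p = p
        self.training = True   # modules start in training mode

    def train(self):
        self.training = True
        return self

    def eval(self):
        self.training = False
        return self

    def __call__(self, xs):
        if not self.training:
            return list(xs)    # eval mode: identity, so outputs are deterministic
        scale = 1.0 / (1.0 - self.p)
        # train mode: zero each element with probability p, scale survivors by 1/(1-p)
        return [0.0 if random.random() < self.p else x * scale for x in xs]

random.seed(0)
drop = Dropout(p=0.5)
x = [1.0, 2.0, 3.0, 4.0]
print(drop(x))          # random: some entries zeroed, the rest doubled
print(drop.eval()(x))   # [1.0, 2.0, 3.0, 4.0]
```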
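Saving and Loading the Entire Model
The previous approach requires instantiating the model class first and then loading the weights.
Is there a way to save and load the whole model directly?
Pass `model` itself to `torch.save` instead of `model.state_dict()`.
Saving:
```python
torch.save(model, 'model.pth')
```
Loading:
```python
# Since PyTorch 2.6, torch.load defaults to weights_only=True, which refuses to
# unpickle arbitrary objects; only use weights_only=False on files you trust.
model = torch.load('model.pth', weights_only=False)
```
This approach serializes the model with Python's pickle module, so the class definition (e.g. `NeuralNetwork`) must be available when loading.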
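The dependency on the class definition comes from pickle itself, which `torch.save(model, ...)` uses by default: the pickled data stores the object's attributes plus a reference to its class by module and name, not the class's code. A plain-Python sketch with a hypothetical `TinyNet` class standing in for `NeuralNetwork`:

```python
import pickle

class TinyNet:
    """Hypothetical model class standing in for NeuralNetwork."""
    def __init__(self):
        self.weights = [0.5, -0.25]

    def forward(self, x):
        return sum(w * xi for w, xi in zip(self.weights, x))

model = TinyNet()
blob = pickle.dumps(model)   # serialize the whole object, like torch.save(model, ...)

# The bytes hold the attribute values and the class *name*, not its source code
print(b"TinyNet" in blob)    # True

restored = pickle.loads(blob)  # works only because TinyNet is importable here
print(restored.forward([2.0, 4.0]))  # 0.0
```

If the loading process cannot find `TinyNet` (for example, a different script that never defines or imports it), `pickle.loads` fails with an AttributeError or ModuleNotFoundError; saving only the state_dict avoids tying the file to that class path.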
This article is from cnblogs, author: GhostCai. When reposting, please cite the original link: https://www.cnblogs.com/ghostcai/p/16209762.html