CNN实现手写数字识别
全部代码如下:
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torchvision import datasets, transforms

# Hyperparameters
batch_size = 64
epochs = 10
learning_rate = 0.01
momentum = 0.5
log_interval = 10  # print the training loss every `log_interval` batches

# Data preparation
# NOTE(review): 0.1307 / 0.3081 are presumably the MNIST training-set
# mean / std used for normalization -- confirm if the dataset changes.
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.1307,), (0.3081,)),
])
train_dataset = datasets.MNIST('data', train=True, download=True,
                               transform=transform)
train_loader = torch.utils.data.DataLoader(train_dataset,
                                           batch_size=batch_size,
                                           shuffle=True)
test_dataset = datasets.MNIST('data', train=False, transform=transform)
# Evaluation only aggregates loss/accuracy over the whole set, so shuffling
# the test data serves no purpose.
test_loader = torch.utils.data.DataLoader(test_dataset,
                                          batch_size=batch_size,
                                          shuffle=False)


# Model definition
class Net(nn.Module):
    """Small CNN: two conv + max-pool stages followed by two linear layers.

    ``forward`` returns per-class log-probabilities (``log_softmax`` over
    dim 1), which is the input expected by ``F.nll_loss``.
    """

    def __init__(self) -> None:
        super().__init__()
        self.conv1 = nn.Conv2d(1, 10, kernel_size=5)
        self.conv2 = nn.Conv2d(10, 20, kernel_size=5)
        # 320 = 20 channels * 4 * 4 spatial positions left after the two
        # conv(5x5) + pool(2) stages when the input is 1x28x28.
        self.fc1 = nn.Linear(320, 500)
        self.fc2 = nn.Linear(500, 10)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return log-probabilities over the 10 output classes for ``x``."""
        x = F.relu(F.max_pool2d(self.conv1(x), 2))
        x = F.relu(F.max_pool2d(self.conv2(x), 2))
        x = x.view(-1, 320)  # flatten the feature maps for the linear layers
        x = F.relu(self.fc1(x))
        x = self.fc2(x)
        return F.log_softmax(x, dim=1)


model = Net()
optimizer = optim.SGD(model.parameters(), lr=learning_rate, momentum=momentum)


# Training
def train(epoch: int) -> None:
    """Run one pass over ``train_loader``, updating the global ``model``.

    Args:
        epoch: Epoch number; only used in the progress messages.
    """
    model.train()
    for batch_idx, (data, target) in enumerate(train_loader):
        optimizer.zero_grad()
        output = model(data)
        loss = F.nll_loss(output, target)
        loss.backward()
        optimizer.step()
        if batch_idx % log_interval == 0:
            # Progress is the fraction of batches processed, scaled to a
            # percentage (the factor must be 100, not 0).
            print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                epoch, batch_idx * len(data), len(train_loader.dataset),
                100. * batch_idx / len(train_loader), loss.item()))


# Evaluation
def test() -> None:
    """Evaluate the global ``model`` on ``test_loader`` and print the results.

    Prints the average negative log-likelihood per sample and the
    classification accuracy over the whole test set.
    """
    model.eval()
    test_loss = 0.0
    correct = 0
    with torch.no_grad():
        for data, target in test_loader:
            output = model(data)
            # Sum (not average) per batch so that dividing by the dataset
            # size below gives a true per-sample mean even when the last
            # batch is smaller.
            test_loss += F.nll_loss(output, target, reduction='sum').item()
            pred = output.argmax(dim=1)  # index of the max log-probability
            correct += pred.eq(target).sum().item()
    test_loss /= len(test_loader.dataset)
    print('\nTest set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'.format(
        test_loss, correct, len(test_loader.dataset),
        100. * correct / len(test_loader.dataset)))


for epoch in range(1, epochs + 1):
    train(epoch)
    test()
作者:太一吾鱼水
文章未经说明均属原创,学习笔记可能有大段的引用,一般会注明参考文献。
欢迎大家留言交流,转载请注明出处。