PyTorch卷积神经网络对MNIST数据集的手写数字识别
这个程序由两个文件组成:一个训练脚本,一个测试脚本。安装好相应依赖环境之后即可进行训练,MNIST数据集由torchvision.datasets.mnist模块自动下载。
mnistTrain.py
# -*- coding: utf-8 -*-
"""Train a small convolutional network on MNIST.

After every epoch the model is scored on the MNIST test split and the
best-scoring model so far is written to ``MODEL_FILE``.
"""
from multiprocessing import cpu_count

import torch
from torch.utils.data import DataLoader

EPOCHS = 25  # number of training epochs
BATCH_SIZE = 64  # images per training batch
TEST_BATCH_SIZE = 128  # images per evaluation batch
LEARNING_RATE = 0.005  # step size handed to Adam
DATA_FOLDER = 'dataset'  # directory the dataset is downloaded into
MODEL_FILE = 'MNIST_CNN.pkl'  # path of the saved model
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')


class CNN(torch.nn.Module):
    """One conv block (conv -> batch norm -> ReLU -> 2x2 max pool) plus a linear head.

    The linear layer takes ``14 * 14 * 32`` features, so with the padded 5x5
    convolution and the 2x2 pooling the network expects single-channel
    28x28 inputs and produces 10 class scores per image.
    """

    def __init__(self) -> None:
        super().__init__()
        self.conv = torch.nn.Sequential(
            torch.nn.Conv2d(1, 32, kernel_size=5, padding=2),
            torch.nn.BatchNorm2d(32),
            torch.nn.ReLU(),
            torch.nn.MaxPool2d(2),
        )
        self.fc = torch.nn.Linear(14 * 14 * 32, 10)

    def forward(self, feature: torch.Tensor) -> torch.Tensor:
        """Return the raw class scores (logits) for a batch of images."""
        out: torch.Tensor = self.conv(feature)
        out = out.flatten(1)  # keep the batch dimension, flatten the rest
        return self.fc(out)


def train_one_epoch(model, batches, loss_func, optimizer, device=DEVICE) -> None:
    """Run one optimisation pass over ``batches``.

    Args:
        model: the network to optimise; it is switched to training mode.
        batches: iterable yielding ``(images, labels)`` tensor pairs.
        loss_func: callable mapping ``(predictions, labels)`` to a scalar loss.
        optimizer: optimizer bound to ``model``'s parameters.
        device: device the batches are moved to before the forward pass.
    """
    model.train()
    for images, labels in batches:
        images = images.to(device)
        labels = labels.to(device)
        loss = loss_func(model(images), labels)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()


def evaluate(model, batches, device=DEVICE) -> float:
    """Return the fraction of correctly classified samples in ``batches``.

    The model is put into eval mode and gradients are disabled, so batch
    normalisation uses its running statistics and is not updated by the
    evaluation data. The previous train/eval mode is restored afterwards.

    Args:
        model: the network to score.
        batches: iterable yielding ``(images, labels)`` tensor pairs.
        device: device the batches are moved to before the forward pass.

    Returns:
        Accuracy in ``[0, 1]``; ``0.0`` when ``batches`` yields nothing.
    """
    was_training = model.training
    model.eval()
    correct = 0
    total = 0
    try:
        with torch.no_grad():
            for images, labels in batches:
                images = images.to(device)
                labels = labels.to(device)
                predicted = model(images).argmax(dim=1)
                correct += (predicted == labels).sum().item()
                total += labels.size(0)
    finally:
        model.train(was_training)
    return correct / total if total else 0.0


def main() -> None:
    """Download MNIST, train for ``EPOCHS`` epochs and save the best model."""
    # Imported here so that importing CNN and the constants from this module
    # (as mnistTest.py does) only requires torch itself.
    from torchvision.datasets.mnist import MNIST
    from torchvision.transforms import ToTensor
    from tqdm import tqdm

    torch.set_num_threads(cpu_count())
    train_data = MNIST(DATA_FOLDER, train=True, transform=ToTensor(), download=True)
    test_data = MNIST(DATA_FOLDER, train=False, transform=ToTensor(), download=True)
    train_loader = DataLoader(train_data, batch_size=BATCH_SIZE, shuffle=True)
    # Accuracy does not depend on sample order, so the test set is not shuffled.
    test_loader = DataLoader(test_data, batch_size=TEST_BATCH_SIZE, shuffle=False)

    cnn = CNN().to(DEVICE)
    loss_func = torch.nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(cnn.parameters(), lr=LEARNING_RATE)

    best_accuracy = 0.0
    for epoch in range(EPOCHS):
        progress = tqdm(train_loader, desc=f'Epoch {epoch + 1}/{EPOCHS}')
        train_one_epoch(cnn, progress, loss_func, optimizer)

        accuracy = evaluate(cnn, test_loader)
        if best_accuracy < accuracy:
            best_accuracy = accuracy
            # The whole module object is saved (not just a state_dict)
            # because mnistTest.py loads it back with a plain torch.load().
            torch.save(cnn, MODEL_FILE)
        print(f'Accuracy: {accuracy * 100}% Best Accuracy: {best_accuracy * 100}%')


if __name__ == '__main__':
    main()
mnistTest.py
# -*- coding: utf-8 -*-
"""Score the model stored in ``MODEL_FILE`` on the MNIST test split."""
# CNN must be importable at module level: torch.load() unpickles a whole
# model object and needs the class definition to rebuild it.
from mnistTrain import CNN, BATCH_SIZE, DATA_FOLDER, DEVICE, MODEL_FILE
import torch
from torchvision.datasets.mnist import MNIST
from torchvision.transforms import ToTensor
from torch.utils.data import DataLoader
from tqdm import tqdm


def main() -> None:
    """Load the saved model, classify the MNIST test set and print the accuracy."""
    test_data = MNIST(DATA_FOLDER, train=False, transform=ToTensor(), download=True)
    # Accuracy does not depend on sample order, so no shuffling is needed.
    test_loader = DataLoader(test_data, batch_size=BATCH_SIZE, shuffle=False)

    # map_location lets a checkpoint saved on a GPU load on a CPU-only machine.
    # weights_only=False is required because the file holds a pickled module,
    # not a state_dict (recent PyTorch releases default to weights_only=True).
    # NOTE(security): this unpickles arbitrary objects -- only load model
    # files that come from a trusted source.
    cnn: CNN = torch.load(MODEL_FILE, map_location=DEVICE, weights_only=False)
    cnn = cnn.to(DEVICE)
    # Eval mode makes batch normalisation use its stored running statistics
    # instead of the statistics of each test batch.
    cnn.eval()

    correct = 0
    with torch.no_grad():  # inference only; no gradients are needed
        for images, labels in tqdm(test_loader):
            images = images.to(DEVICE)
            labels = labels.to(DEVICE)
            # Call the module itself rather than .forward() so hooks still run.
            predictions: torch.Tensor = cnn(images)
            pred: torch.Tensor = predictions.argmax(dim=1)
            correct += (pred == labels).sum().item()

    accuracy = correct / len(test_data.targets)
    print(f'Accuracy: {accuracy * 100}%')


if __name__ == '__main__':
    main()
本文版权,除注明引用的部分外,归作者所有。本文严禁商业用途的转载。非商业用途的转载需在网页明显处署上作者名称及原文链接。
【推荐】国内首个AI IDE,深度理解中文开发场景,立即下载体验Trae
【推荐】编程新体验,更懂你的AI,立即体验豆包MarsCode编程助手
【推荐】抖音旗下AI助手豆包,你的智能百科全书,全免费不限次数
【推荐】轻量又高性能的 SSH 工具 IShell:AI 加持,快人一步
· 全程不用写代码,我用AI程序员写了一个飞机大战
· DeepSeek 开源周回顾「GitHub 热点速览」
· 记一次.NET内存居高不下排查解决与启示
· MongoDB 8.0这个新功能碉堡了,比商业数据库还牛
· 白话解读 Dapr 1.15:你的「微服务管家」又秀新绝活了