Fork me on github

使用PyTorch卷积神经网络对MNIST数据集进行手写数字识别

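这个程序由两个文件组成：一个训练脚本（mnistTrain.py）和一个测试脚本（mnistTest.py）。安装好相应的依赖（torch、torchvision、tqdm）之后即可进行训练，MNIST数据集会通过torchvision.datasets.mnist模块自动下载。训练脚本会把测试准确率最高的模型保存到MNIST_CNN.pkl，测试脚本再加载该文件进行评估，因此需要先运行训练脚本。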

mnistTrain.py

# -*- coding: utf-8 -*-
import torch
from torchvision.datasets.mnist import MNIST
from torchvision.transforms import ToTensor
from torch.utils.data import DataLoader
from multiprocessing import cpu_count
from tqdm import tqdm
# Hyper-parameters and paths; mnistTest.py imports several of these.
EPOCHS: int = 25  # number of full passes over the training set
BATCH_SIZE: int = 64  # images per training batch
DATA_FOLDER: str = 'dataset'  # directory torchvision downloads MNIST into
MODEL_FILE: str = 'MNIST_CNN.pkl'  # where the best-scoring model is saved
# Prefer the GPU when one is visible to PyTorch, otherwise fall back to CPU.
DEVICE: torch.device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
class CNN(torch.nn.Module):
    """Single-stage convolutional classifier for MNIST digits.

    One Conv2d -> BatchNorm2d -> ReLU -> MaxPool2d stage feeds a linear layer
    producing 10 class scores (logits). Inputs must be single-channel images
    whose spatial size pools down to 14x14, i.e. 28x28 MNIST images.
    """

    def __init__(self) -> None:
        super().__init__()
        # Keep the attribute names (conv, fc) and the layer order inside conv
        # unchanged: saved models refer to them (e.g. conv.0.weight).
        conv_layers = [
            torch.nn.Conv2d(1, 32, kernel_size=5, padding=2),  # 28x28 -> 28x28
            torch.nn.BatchNorm2d(32),
            torch.nn.ReLU(),
            torch.nn.MaxPool2d(2),  # 28x28 -> 14x14
        ]
        self.conv = torch.nn.Sequential(*conv_layers)
        self.fc = torch.nn.Linear(32 * 14 * 14, 10)

    def forward(self, feature: torch.Tensor) -> torch.Tensor:
        """Return unnormalised class scores of shape (batch, 10)."""
        return self.fc(torch.flatten(self.conv(feature), start_dim=1))
if __name__ == '__main__':
torch.set_num_threads(cpu_count())
trainData = MNIST(DATA_FOLDER, train=True, transform=ToTensor(), download=True)
testData = MNIST(DATA_FOLDER, train=False, transform=ToTensor(), download=True)
trainLoader = DataLoader(trainData, batch_size=BATCH_SIZE, shuffle=True)
testLoader = DataLoader(testData, batch_size=128, shuffle=True)
cnn = CNN().to(DEVICE)
lossFunc = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(cnn.parameters(), lr=0.005)
bestAccuracy = 0
for epoch in range(EPOCHS):
# Train
for images, labels in tqdm(trainLoader, desc=f'Epoch {epoch + 1}/{EPOCHS}'):
images: torch.Tensor = images.to(DEVICE)
labels: torch.Tensor = labels.to(DEVICE)
predictions: torch.Tensor = cnn(images)
loss: torch.Tensor = lossFunc(predictions, labels)
optimizer.zero_grad()
loss.backward()
optimizer.step()
accuracy = 0
for images, labels in testLoader:
images: torch.Tensor = images.to(DEVICE)
labels: torch.Tensor = labels.to(DEVICE)
predictions: torch.Tensor = cnn(images)
pred: torch.Tensor = predictions.max(dim=1)[1]
accuracy += (pred == labels).sum().item()
accuracy /= len(testData.targets)
if bestAccuracy < accuracy:
bestAccuracy = accuracy
torch.save(cnn, MODEL_FILE)
print(f'Accuracy: {accuracy * 100}% Best Accuracy: {bestAccuracy * 100}%')

mnistTest.py

# -*- coding: utf-8 -*-
from mnistTrain import CNN, BATCH_SIZE, DATA_FOLDER, DEVICE, MODEL_FILE
import torch
from torchvision.datasets.mnist import MNIST
from torchvision.transforms import ToTensor
from torch.utils.data import DataLoader
from tqdm import tqdm
# Evaluate the model saved by mnistTrain.py on the MNIST test split.
if __name__ == '__main__':
    testData = MNIST(DATA_FOLDER, train=False, transform=ToTensor(), download=True)
    # Evaluation order does not affect accuracy, so the test set is not shuffled.
    testLoader = DataLoader(testData, batch_size=BATCH_SIZE, shuffle=False)
    # map_location lets a model saved from a GPU load on a CPU-only machine.
    # weights_only=False (PyTorch >= 1.13) is needed because the file holds a
    # whole pickled module, not a state_dict; PyTorch 2.6 changed the default
    # to True. Unpickling can execute arbitrary code, so only load trusted files.
    # When mnistTrain.py is run as a script, the class is pickled as
    # __main__.CNN; the CNN import above is what makes that name resolvable here.
    cnn: CNN = torch.load(MODEL_FILE, map_location=DEVICE, weights_only=False).to(DEVICE)
    # eval() makes BatchNorm use its running statistics rather than batch ones.
    cnn.eval()
    accuracy = 0
    with torch.no_grad():
        for images, labels in tqdm(testLoader):
            images = images.to(DEVICE)
            labels = labels.to(DEVICE)
            # Call the module (not .forward) so registered hooks still run.
            pred = cnn(images).argmax(dim=1)
            accuracy += (pred == labels).sum().item()
    accuracy /= len(testData.targets)
    print(f'Accuracy: {accuracy * 100}%')
posted @   fang-d  阅读(166)  评论(1)  编辑  收藏  举报
相关博文:
阅读排行:
· 全程不用写代码,我用AI程序员写了一个飞机大战
· DeepSeek 开源周回顾「GitHub 热点速览」
· 记一次.NET内存居高不下排查解决与启示
· MongoDB 8.0这个新功能碉堡了,比商业数据库还牛
· 白话解读 Dapr 1.15:你的「微服务管家」又秀新绝活了
点击右上角即可分享
微信分享提示