PyTorch基于MNIST数据集简单实现手写数字识别
""" 模型训练代码 """ import torch import torchvision.datasets from torch import nn from torchvision import transforms from torch.utils.data import DataLoader import cv2 # 这里我们使用LeNet定义我们的模型 net = nn.Sequential( nn.Conv2d(1, 6, kernel_size=5, padding=2), nn.Sigmoid(), nn.AvgPool2d(kernel_size=2, stride=2), nn.Conv2d(6, 16, kernel_size=5), nn.Sigmoid(), nn.AvgPool2d(kernel_size=2, stride=2), nn.Flatten(), nn.Linear(16 * 5 * 5, 120), nn.Sigmoid(), nn.Linear(120, 84), nn.Sigmoid(), nn.Linear(84, 10) ) train_data = torchvision.datasets.MNIST(root='MNIST', train=True, transform=torchvision.transforms.ToTensor(), download=True ) test_data = torchvision.datasets.MNIST(root='MNIST', train=False, transform=torchvision.transforms.ToTensor(), download=True ) train_loader = DataLoader(train_data, batch_size=100, shuffle=True) test_loader = DataLoader(test_data, batch_size=100, shuffle=True) # images, lables = next(iter(train_loader)) # img = torchvision.utils.make_grid(images, nrow=10) # 把若干图像拼接成一张图像 # img = img.numpy().transpose(1, 2, 0) # cv2.imshow('img', img) # cv2.waitKey(0) # for data in train_loader: # imgs, target = data # # print(imgs.shape) # # print(target.shape) # print(target) # # print(data[0].shape) # (100 , 1, 28, 28) 100个皮偏高 1个通道 28 * 28 的图像 # break loss = nn.CrossEntropyLoss() # 损失函数 optim = torch.optim.Adam(net.parameters(), lr=0.001) # 优化器 num_epochs = 20 for epoch in range(num_epochs): sum_loss = 0.0 for data in train_loader: imgs, targets = data outputs = net(imgs) result_loss = loss(outputs, targets) optim.zero_grad() # 梯度清零 result_loss.backward() optim.step() # 进行优化 sum_loss = sum_loss + result_loss print(f'epoch:{epoch + 1},训练误差 :{sum_loss/len(train_data)}') # 测试 net.eval() test_acc = 0 for data in test_loader: imgs, targets = data outputs = net(imgs) _, id = torch.max(outputs.data, 1) # 1表示维度 返回概率最大的索引 test_acc += torch.sum(id == targets.data) print("测试误差:%.3f" %((test_acc * 100) / len(test_data))) # 模型的保存 torch.save(net.state_dict(), "net_parameters.pth")
""" 简易可视化 """ import torch import torchvision.datasets from torch import nn from d2l import torch as d2l from torchvision import transforms from torch.utils.data import DataLoader import cv2 net = nn.Sequential( nn.Conv2d(1, 6, kernel_size=5, padding=2),nn.Sigmoid(), nn.AvgPool2d(kernel_size=2, stride=2), nn.Conv2d(6, 16, kernel_size=5), nn.Sigmoid(), nn.AvgPool2d(kernel_size=2, stride=2), nn.Flatten(), nn.Linear(16 * 5 * 5, 120), nn.Sigmoid(), nn.Linear(120, 84), nn.Sigmoid(), nn.Linear(84, 10) ) net.load_state_dict(torch.load(r"D:\PycharmProjects\pytorch_study\easy_test\net_parameters.pth")) test_data = torchvision.datasets.MNIST(root='MNIST', train=False, transform=torchvision.transforms.ToTensor(), download=True ) test_loader = DataLoader(test_data, batch_size=1, shuffle=True) """ 下面这里只是一个简单的可视化,读者可自行优化 """ for data in test_loader: imgs, targets = data output = net(imgs) print(torch.topk(output, 1)[1].squeeze(0)) img = imgs.numpy().reshape((28, 28)) cv2.imshow('img', img) cv2.waitKey(0) break
AI大三在读
本文作者:Shedlon2
本文链接:https://www.cnblogs.com/Sheldon2/p/16906029.html
版权声明:本作品采用知识共享署名-非商业性使用-禁止演绎 2.5 中国大陆许可协议进行许可。
【推荐】国内首个AI IDE,深度理解中文开发场景,立即下载体验Trae
【推荐】编程新体验,更懂你的AI,立即体验豆包MarsCode编程助手
【推荐】抖音旗下AI助手豆包,你的智能百科全书,全免费不限次数
【推荐】轻量又高性能的 SSH 工具 IShell:AI 加持,快人一步