I don't know what|

Sheldon2

园龄:3年3个月粉丝:2关注:4

PyTorch 基于 MNIST 数据集简单实现手写数字识别

"""
模型训练代码
"""
import torch
import torchvision.datasets
from torch import nn
from torchvision import transforms
from torch.utils.data import DataLoader
import cv2
# The model is a LeNet-style convolutional network. The layers are listed in
# execution order and handed to nn.Sequential positionally, so the saved
# state_dict keys stay index-based ("0.weight", "3.weight", "7.weight", ...).
lenet_layers = [
    # First convolution stage: 1 input channel -> 6 feature maps.
    nn.Conv2d(1, 6, kernel_size=5, padding=2),
    nn.Sigmoid(),
    nn.AvgPool2d(kernel_size=2, stride=2),
    # Second convolution stage: 6 -> 16 feature maps.
    nn.Conv2d(6, 16, kernel_size=5),
    nn.Sigmoid(),
    nn.AvgPool2d(kernel_size=2, stride=2),
    # Fully connected classifier; 16 * 5 * 5 is the flattened feature size
    # produced by the stages above for a 28 x 28 input.
    nn.Flatten(),
    nn.Linear(16 * 5 * 5, 120),
    nn.Sigmoid(),
    nn.Linear(120, 84),
    nn.Sigmoid(),
    nn.Linear(84, 10),  # one output score per digit class
]
net = nn.Sequential(*lenet_layers)
# Options shared by both MNIST splits: files are stored under ./MNIST,
# downloaded on first use, and every image is converted to a tensor.
mnist_options = dict(
    root='MNIST',
    transform=torchvision.transforms.ToTensor(),
    download=True,
)
train_data = torchvision.datasets.MNIST(train=True, **mnist_options)
test_data = torchvision.datasets.MNIST(train=False, **mnist_options)

# Both loaders serve shuffled mini-batches of 100 samples.
train_loader = DataLoader(dataset=train_data, batch_size=100, shuffle=True)
test_loader = DataLoader(dataset=test_data, batch_size=100, shuffle=True)

# Optional sanity checks, intentionally left disabled.
#
# Preview one training batch as a single tiled image:
# images, labels = next(iter(train_loader))
# img = torchvision.utils.make_grid(images, nrow=10)  # tile the batch into one image
# img = img.numpy().transpose(1, 2, 0)
# cv2.imshow('img', img)
# cv2.waitKey(0)
#
# Inspect the contents of the first batch:
# for data in train_loader:
#     imgs, target = data
#     # print(imgs.shape)
#     # print(target.shape)
#     print(target)
#     # print(data[0].shape)  # (100, 1, 28, 28): 100 images, 1 channel, 28 x 28 pixels
#     break
# Training configuration: cross-entropy loss optimised with Adam.
loss = nn.CrossEntropyLoss()  # loss function
optim = torch.optim.Adam(net.parameters(), lr=0.001)  # optimizer
num_epochs = 20

# NOTE(review): the original indentation was lost, so the nesting below is a
# reconstruction: evaluate on the test split after every epoch and save the
# weights once, after the final epoch.
for epoch in range(num_epochs):
    # Evaluation below switches the network to eval mode; switch back to
    # training mode at the start of every epoch.
    net.train()
    sum_loss = 0.0
    for imgs, targets in train_loader:
        outputs = net(imgs)
        result_loss = loss(outputs, targets)
        optim.zero_grad()  # clear gradients from the previous step
        result_loss.backward()
        optim.step()  # apply the parameter update
        # Accumulate a plain float; adding the tensor itself would keep
        # autograd history attached to the running sum.
        sum_loss += result_loss.item()
    # Each batch loss is already a mean over its samples, so the epoch
    # average divides by the number of batches, not the number of samples.
    print(f'epoch:{epoch + 1},训练误差 :{sum_loss / len(train_loader)}')

    # Evaluation: count correct predictions on the test split.
    net.eval()
    test_acc = 0
    with torch.no_grad():  # no gradients are needed for inference
        for imgs, targets in test_loader:
            outputs = net(imgs)
            predicted = outputs.argmax(dim=1)  # index of the highest score per sample
            test_acc += (predicted == targets).sum().item()
    # The printed value is an accuracy percentage, so it is labelled as
    # accuracy rather than error.
    print("测试准确率:%.3f" % ((test_acc * 100) / len(test_data)))

# Save the trained weights.
torch.save(net.state_dict(), "net_parameters.pth")
"""
简易可视化
"""
import torch
import torchvision.datasets
from torch import nn
from d2l import torch as d2l
from torchvision import transforms
from torch.utils.data import DataLoader
import cv2
# Rebuild the LeNet-style architecture so a saved state_dict can be loaded
# into it. Layer order determines the index-based parameter names, so it
# must stay exactly: conv stages, flatten, then the linear classifier.
feature_extractor = (
    nn.Conv2d(1, 6, kernel_size=5, padding=2),
    nn.Sigmoid(),
    nn.AvgPool2d(kernel_size=2, stride=2),
    nn.Conv2d(6, 16, kernel_size=5),
    nn.Sigmoid(),
    nn.AvgPool2d(kernel_size=2, stride=2),
)
classifier = (
    nn.Linear(16 * 5 * 5, 120),
    nn.Sigmoid(),
    nn.Linear(120, 84),
    nn.Sigmoid(),
    nn.Linear(84, 10),  # one output score per digit class
)
net = nn.Sequential(*feature_extractor, nn.Flatten(), *classifier)
# Restore the trained weights. The training script saves them as
# "net_parameters.pth" relative to the working directory, so the same
# relative path is used here instead of an absolute path that exists on one
# machine only. map_location="cpu" lets the checkpoint load on machines
# without a GPU.
net.load_state_dict(torch.load("net_parameters.pth", map_location="cpu"))

# MNIST test split, served one image at a time in random order.
test_data = torchvision.datasets.MNIST(root='MNIST',
                                       train=False,
                                       transform=torchvision.transforms.ToTensor(),
                                       download=True
                                       )
test_loader = DataLoader(test_data, batch_size=1, shuffle=True)

# What follows is only a minimal visualization; readers are free to improve it.
# Show one test digit together with the model's predicted class.
net.eval()  # inference mode
for data in test_loader:
    imgs, _ = data  # the ground-truth label is not used here
    with torch.no_grad():  # no gradients are needed for inference
        output = net(imgs)
    # Index of the highest score, printed as a one-element tensor.
    print(output.argmax(dim=1))
    # batch_size is 1, so the batch collapses to a single 28 x 28 image.
    img = imgs.numpy().reshape((28, 28))
    cv2.imshow('img', img)
    cv2.waitKey(0)  # wait for a key press before continuing
    break  # only the first sample is displayed

本文作者:Sheldon2

本文链接:https://www.cnblogs.com/Sheldon2/p/16906029.html

版权声明:本作品采用知识共享署名-非商业性使用-禁止演绎 2.5 中国大陆许可协议进行许可。

posted @   Sheldon2  阅读(63)  评论(0)  编辑  收藏  举报
点击右上角即可分享
微信分享提示
评论
收藏
关注
推荐
深色
回顶
收起