Pytorch 搭建 LeNet-5 网络

1 数据集

Mnist 数据集是一个手写数字图片数据集,数据集的下载和解读详见 Mnist数据集解读

这里为了对接 pytorch 的神经网络,需要将数据集制作成可以批量读取的 tensor 数据。采用 torch.utils.data.Dataset 构建。

data.py

import os
import numpy as np
from torch.utils.data import Dataset
import gzip


class Mnist(Dataset):
    def __init__(self, root, train=True, transform=None):

        # 根据是否为训练集,得到文件名前缀
        self.file_pre = 'train' if train == True else 't10k'
        self.transform = transform

        # 生成对应数据集的图片和标签文件路径
        self.label_path = os.path.join(root,
                                       '%s-labels-idx1-ubyte.gz' % self.file_pre)
        self.image_path = os.path.join(root,
                                       '%s-images-idx3-ubyte.gz' % self.file_pre)

        # 读取文件数据,返回图片和标签
        self.images, self.labels = self.__read_data__(
            self.image_path,
            self.label_path)

    def __read_data__(self, image_path, label_path):
        # 数据集读取
        with gzip.open(label_path, 'rb') as lbpath:
            labels = np.frombuffer(lbpath.read(), np.uint8,
                                   offset=8)
        with gzip.open(image_path, 'rb') as imgpath:
            images = np.frombuffer(imgpath.read(), np.uint8,
                                   offset=16).reshape(len(labels), 28, 28)
        return images, labels

    def __getitem__(self, index):
        image, label = self.images[index], int(self.labels[index])

        # 如果需要转成 tensor 则使用 tansform
        if self.transform is not None:
            image = self.transform(np.array(image))  # 此处需要用 np.array(image)
        return image, label

    def __len__(self):
        return len(self.labels)


if __name__ == '__main__':

    # 生成实例
    train_set = Mnist(
        root=r'H:\\Dataset\\Mnist',
        train=False,
    )

    # 取一组数据并展示
    (data, label) = train_set[0]
    import matplotlib.pyplot as plt
    plt.imshow(data.reshape(28, 28), cmap='gray')
    plt.title('label is :{}'.format(label))
    plt.show()

总体思路:指定Mnist数据集的存储路径后,根据是否为训练集,找到对应的压缩包(图像和标签),解压文件并读取数据,利用 Dataset 构造迭代器,从而实现根据索引号返回一组图像和标签的数据。

Dataset 是一个抽象类,需要继承并重写。其中,根据Mnist数据集文件的命名和存储结构,构造了一个__read_data__ 私有函数,用来读取数据,返回图像和标签值;在__init__ 中,初始化数据集,获取到原始的数据;在__getitem__ 中,根据 index ,返回一组图像和标签,这里可以对图像进行变换(可选,例如转成tensor, 归一化等等);在 __len__ 中返回数据集的样本个数。

为了看懂最后输出的内容,生成了一个实例,取出一组数据,并展示,结果如下:
图1.1 从数据集中取出一张图展示

2 模型构建

图2.1 LeNet-5模型架构图

LeNet-5 神经网络一共五层,其中卷积层和池化层可以考虑为一个整体,网络的结构为 :

输入 → 卷积 → 池化 → 卷积 → 池化 → 全连接 → 全连接 → 全连接 → 输出。

pytorch 中,图像数据集的存储顺序为:(batch, channels, height, width),依次为批大小、通道数、高度、宽度。所以,按照网络结构,各层的参数和输入输出关系,可以整理得到下表:

表2.1 LeNet-5模型参数表
操作 操作参数 输入/输出尺寸
input batch: ?
channels: 1
height: 28
width: 28
input:(batch, 1, 28, 28)
output: (batch, 1, 28, 28)
conv1 in_channels: 1
out_channels: 6
kernel_size: 5×5
padding: 0
stride: 1
input: (batch, 1, 28, 28)
output:(batch, 6, 24, 24)
pool1 kernel_size: 2×2 input:(batch, 6, 24, 24)
output:(batch, 6, 12, 12)
conv2 in_channels: 6
out_channels: 16
kernel_size: 5×5
padding: 0
stride: 1
input:(batch, 6, 12, 12)
output:(batch, 16, 8, 8)
pool2 kernel_size: 2×2 input:(batch, 16, 8, 8)
output:(batch, 16, 4, 4)
fc1 in: 16×4×4
out: 120
input:(batch, 16*4*4)
output:(batch, 120)
fc2 in: 120
out: 84
input:(batch, 120)
output:(batch,84)
fc3 in: 84
out:10
input:(batch,84)
output:(batch, 10)

如上表所示,输入的Mnist数据集是灰度图,通道为1,长和宽都为28。经过pytorch处理后,可以生成批量数据,从而多出一个batch的维度数据。

这里需要特别注意的是,从第二次卷积池化后,与全连接层fc1进行数据传递时,是先把池化pool2的输出,除了batch之外的其他维度数据,展平到一个维度。然后送入全连接层。而这个数据的大小跟输入大小有关,因此在设计时,需要仔细推算每一层的输出。

由上面的分析,就可以搭建网络了。

model.py

import torch.nn as nn
import torch.nn.functional as F


class LeNet5(nn.Module):
    def __init__(self):
        super(LeNet5,self).__init__()
        self.conv1 = nn.Conv2d(1, 6, 5)
        self.pool1 = nn.MaxPool2d(2, 2)
        self.conv2 = nn.Conv2d(6, 16, 5)
        self.pool2 = nn.MaxPool2d(2, 2)
        self.fc1 = nn.Linear(16*4*4, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 10)
    
    def forward(self,x):
        x = self.pool1(F.relu(self.conv1(x)))
        x = self.pool2(F.relu(self.conv2(x)))
        x = x.view(-1, 16*4*4)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x


if __name__ == '__main__':
    net = LeNet5()
    print(net)

主要思路

网络的构建需要继承 torch.nn.Module ,在 _init__ 中和forward 中其实都是可以定义网络的,但是,一般是在__init__ 里定义一些主要的操作,然后在 forward 里输入数据,进行前向传播的表达。其中展平的操作利用 view() 实现,前面的 -1 表示默认,即batch的大小,后面则是其余维度展平后的大小。

为了看清楚网络的各层参数,将其打印了:

LeNet5(
  (conv1): Conv2d(1, 6, kernel_size=(5, 5), stride=(1, 1))
  (pool1): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  (conv2): Conv2d(6, 16, kernel_size=(5, 5), stride=(1, 1))
  (pool2): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  (fc1): Linear(in_features=256, out_features=120, bias=True)
  (fc2): Linear(in_features=120, out_features=84, bias=True)
  (fc3): Linear(in_features=84, out_features=10, bias=True)
)

3 训练与测试

神经网络的训练主要包括了导入批数据,前向传播,反向传播,权重更新,如此循环迭代。遍历到一定的epoch数量后停止,得到训练好的模型。

随后,将图像送进网络进行测试即可。

main.py

import torch
import torchvision.transforms as transforms
import torch.optim as optim
from torch.utils.data import DataLoader
from data import Mnist
from model import LeNet5


# 生成训练集
train_set = Mnist(
    root=r'H:\\Dataset\\Mnist',
    train=True,
    transform=transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.1037,), (0.3081,))
    ])
)
train_loader = DataLoader(
    dataset=train_set,
    batch_size=32,
    shuffle=True
)


# 实例化一个网络
net = LeNet5()

# 定义损失函数和优化器
loss_function = torch.nn.CrossEntropyLoss()
optimizer = optim.SGD(
    net.parameters(),
    lr=0.001,
    momentum=0.9
)

# 3 训练模型
loss_list = []
for epoch in range(10):
    running_loss = 0.0
    for batch_idx, data in enumerate(train_loader, start=0):

        images, labels = data                       # 读取一个batch的数据
        optimizer.zero_grad()                       # 梯度清零,初始化
        outputs = net(images)                       # 前向传播
        loss = loss_function(outputs, labels)       # 计算误差
        loss.backward()                             # 反向传播
        optimizer.step()                            # 权重更新
        running_loss += loss.item()                 # 误差累计

        # 每300个batch 打印一次损失值
        if batch_idx % 300 == 299:
            print('epoch:{} batch_idx:{} loss:{}'
                  .format(epoch+1, batch_idx+1, running_loss/300))
            loss_list.append(running_loss/300)
            running_loss = 0.0                  #误差清零

print('Finished Training.')


# 打印损失值变化曲线
import matplotlib.pyplot as plt
plt.plot(loss_list)
plt.title('traning loss')
plt.xlabel('epochs')
plt.ylabel('loss')
plt.show()


# 测试
test_set = Mnist(
    root='H:\\Dataset\\Mnist',
    train=False,
    transform=transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.1037,), (0.3081,))
    ])
)
test_loader = DataLoader(
    dataset=test_set,
    batch_size=32,
    shuffle=True
)

correct = 0  # 预测正确数
total = 0    # 总图片数

for data in test_loader:
    images, labels = data
    outputs = net(images)
    _, predict = torch.max(outputs.data, 1)
    total += labels.size(0)
    correct += (predict == labels).sum()

print('测试集准确率 {}%'.format(100*correct // total))


# 测试自己手动设计的手写数字
from PIL import Image
I = Image.open('8.jpg')
L = I.convert('L')
plt.imshow(L, cmap='gray')

transform=transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.1037,), (0.3081,))
])
    
im = transform(L)  # [C, H, W]
im = torch.unsqueeze(im, dim=0)  # [N, C, H, W]

with torch.no_grad():
    outputs = net(im)
    _, predict = torch.max(outputs.data, 1)
    print(predict)

总体思路

利用 torch.utils.data.DataLoader 从数据集中划分批次,然后打乱顺序,每次送入一个批次的数据到神经网络进行训练,每300个批次计算一次损失值。训练结束后,测试了在测试集上的准确率。最后又测试自己手动制作的单一的手写数字图像。

结果如下:
图3.1 loss 迭代曲线

测试集准确率 99%
tensor([8])

图3.2 测试手写图

4 运行界面

图4.1 运行界面
spyder 简直是 matlab转python的绝佳选择!

5 总结

本次LeNet-5 网络的是最基础的,其构建过程是所有其他网络的基本范式。通过这次搭建,我们熟悉了如何导入自己制作的数据集(虽然是数据是网上下载的,但也需要一定的过程转成可用的数据格式);了解网络的搭建方法,分析了其参数和输入输出关系,弄懂了其中卷积池化后与全连接的之间维度上的匹配问题;最后成功地实现了较高的识别准确率。

需要改进的地方:

  1. 模型评估改进,希望生成具体的测试集和训练集损失函数迭代曲线,以及准确率的迭代曲线。

  2. 代码优化,希望将数据集、训练、测试、评价、应用等环节模块化。

posted @ 2020-06-12 14:25  GShang  阅读(5539)  评论(2编辑  收藏  举报