使用Pytorch手动训练VGG11

在前面的博客中我们提到如何用pytorch搭建一个VGG11网络框架;

详见使用Pytorch搭建VGG网络——以VGG11为例

 

在本博客中,我们将使用之前搭建的VGG11网络,同时对其进行手动训练,使我们可以更好的理解模型建立和训练的过程;

主要内容:

  • 数据集和目录结构——使用数字手写辨识来训练VGG11
  • 编码部分:

    1)数据集准备;

    2)训练和验证模型;

    3)优化器

    4)验证每个epoch的精度;

  • 分析训练的损失和精度;
  • 在图像上训练和测试数据

一、数据集和目录结构

数据集:手写体 Digit MNIST

使用torchvision.dataset模块加载;

目录结构:

├── input
│   └── test_data
│       ├── eight.jpg
│       ├── two.jpg
│       └── zero.jpg
├── outputs
│   ├── accuracy.jpg
│   └── loss.jpg
|   ...
├── src
│   ├── data
│   │   └── MNIST
│   ...
│   ├── models.py
│   ├── test.py
│   └── train.py

 

二、 编码部分:

1)  网络模型脚本:VGG11模型的加载,见 使用Pytorch搭建VGG网络——以VGG11为例

2)训练脚本train.py的编写:

2.1)相关包的导入:

 1 import torch
 2 import torchvision
 3 import torchvision.transforms as transforms
 4 import matplotlib.pyplot as plt
 5 import matplotlib
 6 import torch.nn as nn
 7 import torch.optim as optim
 8 from tqdm import tqdm
 9 from models import VGG11
10 matplotlib.style.use('ggplot')

2.2)定义模型的参数及设备:

1 device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
2 print(f"[INFO]: Computation device: {device}")
3 epochs = 10
4 batch_size = 32

出现了OOM(out of memory)错误的话,可以减少batch到16、8或者4,来迎合你的GPU;

2.3) 图像变换

 1 # our transforms will differ a bit from the VGG paper
 2 # as we are using the MNIST dataset, so, we will directly resize...
 3 # ... the images to 224x224 and not crop them and we will not use...
 4 # ... any random flippings also
 5 train_transform = transforms.Compose(
 6     [transforms.Resize((224, 224)),
 7      transforms.ToTensor(),
 8      transforms.Normalize(mean=(0.5), std=(0.5))])
 9 valid_transform = transforms.Compose(
10     [transforms.Resize((224, 224)),
11      transforms.ToTensor(),
12      transforms.Normalize(mean=(0.5), std=(0.5))])

注:原始论文中有图像的翻转flipping,在我们的训练中是不需要的;

Line6把图像进行resize()

line7把图像变成tensor

line7标准化处理;

————————————————————————————————

2.4)数据的加载

接下来准备训练和验证数据集以及data loaders。

 1 # training dataset and data loader
 2 train_dataset = torchvision.datasets.MNIST(root='./data', train=True,
 3                                              download=True, 
 4                                              transform=train_transform)
 5 train_dataloader = torch.utils.data.DataLoader(train_dataset, 
 6                                                batch_size=batch_size,
 7                                                shuffle=True)
 8 # validation dataset and dataloader
 9 valid_dataset = torchvision.datasets.MNIST(root='./data', train=False,
10                                            download=True, 
11                                            transform=valid_transform)
12 valid_dataloader = torch.utils.data.DataLoader(valid_dataset, 
13                                              batch_size=batch_size,
14                                              shuffle=False)

2.5)模型的初始化、损失函数以及优化器

 1 # instantiate the model
 2 model = VGG11(in_channels=1, num_classes=10).to(device)
 3 # total parameters and trainable parameters
 4 total_params = sum(p.numel() for p in model.parameters())
 5 print(f"[INFO]: {total_params:,} total parameters.")
 6 total_trainable_params = sum(
 7     p.numel() for p in model.parameters() if p.requires_grad)
 8 print(f"[INFO]: {total_trainable_params:,} trainable parameters.")
 9 # the loss function
10 criterion = nn.CrossEntropyLoss()
11 # the optimizer
12 optimizer = optim.SGD(model.parameters(), lr=0.01, momentum=0.9, 
13                       weight_decay=0.0005)
  • 处理的MNIS是灰度图像,所以输入通道为1;输出的类别是10;
  • 使用的是交叉熵损失函数;SGD用于进行参数更新;

2.6)训练函数

 1 # training
 2 def train(model, trainloader, optimizer, criterion):
 3     model.train()
 4     print('Training')
 5     train_running_loss = 0.0
 6     train_running_correct = 0
 7     counter = 0
 8     for i, data in tqdm(enumerate(trainloader), total=len(trainloader)):
 9         counter += 1
10 
11         image, labels = data
12         image = image.to(device)
13         labels = labels.to(device)
14 
15         optimizer.zero_grad()
16         # forward pass
17         outputs = model(image)
18         # calculate the loss
19         loss = criterion(outputs, labels)
20         train_running_loss += loss.item()
21         # calculate the accuracy
22         _, preds = torch.max(outputs.data, 1)
23         train_running_correct += (preds == labels).sum().item()
24         loss.backward()
25         optimizer.step()
26     
27     epoch_loss = train_running_loss / counter
28     epoch_acc = 100. * (train_running_correct / len(trainloader.dataset))
29     return epoch_loss, epoch_acc
  • 迭代训练数据加载器,提取标签和图像;
  • 加载图像和标签到计算设备上;
  • 前向传播,计算损失和精度;后向传播损失;
  • 返回目前epoch的损失和精度;

2.8)验证函数:

验证函数有一些不同。对每个epoch中,我们将会计算损失和精度;

我们将对每个类别计算精度,评估在每个epoch中我们的模型表现力;

 1 # validation
 2 def validate(model, testloader, criterion):
 3     model.eval()
 4     
 5     # we need two lists to keep track of class-wise accuracy
 6     class_correct = list(0. for i in range(10))
 7     class_total = list(0. for i in range(10))
 8     print('Validation')
 9     valid_running_loss = 0.0
10     valid_running_correct = 0
11     counter = 0
12     with torch.no_grad():
13         for i, data in tqdm(enumerate(testloader), total=len(testloader)):
14             counter += 1
15             
16             image, labels = data
17             image = image.to(device)
18             labels = labels.to(device)
19             # forward pass
20             outputs = model(image)
21             # calculate the loss
22             loss = criterion(outputs, labels)
23             valid_running_loss += loss.item()
24             # calculate the accuracy
25             _, preds = torch.max(outputs.data, 1)
26             valid_running_correct += (preds == labels).sum().item()
27             # calculate the accuracy for each class
28             correct  = (preds == labels).squeeze()
29             for i in range(len(preds)):
30                 label = labels[i]
31                 class_correct[label] += correct[i].item()
32                 class_total[label] += 1
33         
34     epoch_loss = valid_running_loss / counter
35     epoch_acc = 100. * (valid_running_correct / len(testloader.dataset))
36     # print the accuracy for each class after evey epoch
37     # the values should increase as the training goes on
38     print('\n')
39     for i in range(10):
40         print(f"Accuracy of digit {i}: {100*class_correct[i]/class_total[i]}")
41     return epoch_loss, epoch_acc

 

2.9)训练过程:

 1 # start the training
 2 # lists to keep track of losses and accuracies
 3 train_loss, valid_loss = [], []
 4 train_acc, valid_acc = [], []
 5 for epoch in range(epochs):
 6     print(f"[INFO]: Epoch {epoch+1} of {epochs}")
 7 
 8     train_epoch_loss, train_epoch_acc = train(model, train_dataloader, 
 9                                               optimizer, criterion)
10     valid_epoch_loss, valid_epoch_acc = validate(model, valid_dataloader,  
11                                                  criterion)
12     train_loss.append(train_epoch_loss)
13     valid_loss.append(valid_epoch_loss)
14     train_acc.append(train_epoch_acc)
15     valid_acc.append(valid_epoch_acc)
16 
17     print('\n')
18     print(f"Training loss: {train_epoch_loss:.3f}, training acc: {train_epoch_acc:.3f}")
19     print(f"Validation loss: {valid_epoch_loss:.3f}, validation acc: {valid_epoch_acc:.3f}")
20 
21     print('-'*50)
  • 把训练和验证的损失和精度存储到train_loss、valid_loss以及train_acc和valid_acc中;
  • 在每个epoch中,我们打印训练和损失的指标;

最后的步骤存储在训练的模型中,绘制出损失和精度图;

 

三、分析训练的损失和精度;

 

 图1 精度图

 

  图2 损失图

 

四、对新的图像进行推断

1 import torch
2 import cv2
3 import glob as glob
4 import torchvision.transforms as transforms
5 import numpy as np
6 from models import VGG11

加载训练的权重,定义Transforms

 1 # inferencing on CPU
 2 device = 'cpu'
 3 # initialize the VGG11 model
 4 model = VGG11(in_channels=1, num_classes=10)
 5 # load the model checkpoint
 6 checkpoint = torch.load('../outputs/model.pth')
 7 # load the trained weights
 8 model.load_state_dict(checkpoint['model_state_dict'])
 9 model.to(device)
10 model.eval()
11 # simple image transforms
12 transform = transforms.Compose([
13     transforms.ToPILImage(),
14     transforms.Resize((224, 224)),
15     transforms.ToTensor(),
16     transforms.Normalize(mean=[0.5],
17                          std=[0.5])
18 ])

读图像并将其传入模型中

 1 # get all the test images path
 2 image_paths = glob.glob('../input/test_data/*.jpg')
 3 for i, image_path in enumerate(image_paths):
 4     orig_img = cv2.imread(image_path)
 5     # convert to grayscale to make the image single channel
 6     image = cv2.cvtColor(orig_img, cv2.COLOR_BGR2GRAY)
 7     image = transform(image)
 8     # add one extra batch dimension
 9     image = image.unsqueeze(0).to(device)
10     # forward pass the image through the model
11     outputs = model(image)
12     # get the index of the highest score
13     # the highest scoring indicates the label for the Digit MNIST dataset
14     label = np.array(outputs.detach()).argmax()
15     print(f"{image_path.split('/')[-1].split('.')[0]}: {label}")
16     # put the predicted label on the original image
17     cv2.putText(orig_img, str(label), (15, 50), cv2.FONT_HERSHEY_SIMPLEX, 
18                 2, (0, 255, 0), 2)
19     # show and save the resutls
20     cv2.imshow('Result', orig_img)
21     cv2.waitKey(0)
22     cv2.imwrite(f"../outputs/result_{i}.jpg", orig_img)

总结:

本文我们主要训练了一个VGG11的神经网络从手写体MNIST数据集中;

我们以模型的初始化、训练模型和观察模型的精度即损失为主线进行了说明;

最后在新的数据集上对模型进行了验证。

 

posted @ 2021-05-21 15:18  hi_mxd  阅读(955)  评论(0编辑  收藏  举报