简单易懂的 PyTorch 实战实例：VGG 深度网络

模型为 VGG，数据集为 CIFAR-10。对照这份代码走一遍，大致就能了解 PyTorch 的整体运行机制。

来源

定义模型（保存为可导入的 models 模块，供训练脚本使用）：

    '''VGG11/13/16/19 in Pytorch.'''
    import torch
    import torch.nn as nn
    from torch.autograd import Variable
    
    
    # Layer recipes per VGG depth: an int is the output-channel count of a
    # 3x3 Conv2d (+BatchNorm2d+ReLU) layer, 'M' is a 2x2 max-pool with stride 2.
    # Every recipe contains five 'M' pools and ends with 512 channels.
    cfg = {
      'VGG11': [64, 'M', 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M'],
      'VGG13': [64, 64, 'M', 128, 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M'],
      'VGG16': [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 'M', 512, 512, 512, 'M', 512, 512, 512, 'M'],
      'VGG19': [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 256, 'M', 512, 512, 512, 512, 'M', 512, 512, 512, 512, 'M'],
    }

    # Models must subclass nn.Module.
    class VGG(nn.Module):
      '''VGG-style image classifier built from one of the `cfg` recipes.

      The classifier is a single Linear(512, num_classes) layer, so the feature
      extractor must reduce each image to 512x1x1. With the five 2x2 pools in
      every recipe, that means 32x32 RGB inputs (e.g. CIFAR-10).

      Args:
        vgg_name: key of `cfg` ('VGG11', 'VGG13', 'VGG16' or 'VGG19').
        num_classes: number of output logits (default 10, as for CIFAR-10).
      '''
      def __init__(self, vgg_name, num_classes=10):
        super(VGG, self).__init__()
        self.features = self._make_layers(cfg[vgg_name])
        self.classifier = nn.Linear(512, num_classes)

      def forward(self, x):
        '''Return class logits of shape (batch, num_classes) for the image batch `x`.'''
        out = self.features(x)
        # Flatten (batch, 512, 1, 1) -> (batch, 512) for the linear classifier
        # (assumes 32x32 input, see class docstring).
        out = out.view(out.size(0), -1)
        return self.classifier(out)

      def _make_layers(self, layer_cfg):
        '''Build the convolutional feature extractor described by `layer_cfg`.'''
        layers = []
        in_channels = 3
        for x in layer_cfg:
          if x == 'M':
            layers += [nn.MaxPool2d(kernel_size=2, stride=2)]
          else:
            layers += [nn.Conv2d(in_channels, x, kernel_size=3, padding=1),
                       nn.BatchNorm2d(x),
                       nn.ReLU(inplace=True)]
            in_channels = x
        # A 1x1 average pool with stride 1 is an identity op; it is kept so the
        # layer list (and `features` indices) match the original definition.
        layers += [nn.AvgPool2d(kernel_size=1, stride=1)]
        return nn.Sequential(*layers)
    
    # net = VGG('VGG11')
    # x = torch.randn(2,3,32,32)
    # print(net(Variable(x)).size())
    
    

定义训练过程（main.py；它通过 `from models import *` 导入上面的模型，通过 `from utils import progress_bar` 导入文末的工具函数）：

    '''Train CIFAR10 with PyTorch.'''
    from __future__ import print_function
    
    import torch
    import torch.nn as nn
    import torch.optim as optim
    import torch.nn.functional as F
    import torch.backends.cudnn as cudnn
    
    import torchvision
    import torchvision.transforms as transforms
    
    import os
    import argparse
    
    from models import *
    from utils import progress_bar
    from torch.autograd import Variable
    
    # 获取参数
    # Command-line options: the initial learning rate and whether to resume
    # from ./checkpoint instead of building a fresh model.
    parser = argparse.ArgumentParser(description='PyTorch CIFAR10 Training')
    parser.add_argument('--lr', type=float, default=0.1, help='learning rate')
    parser.add_argument('--resume', '-r', help='resume from checkpoint', action='store_true')
    args = parser.parse_args()

    # Train on the GPU whenever PyTorch can see one.
    use_cuda = torch.cuda.is_available()
    # Best test accuracy seen so far, and the first epoch to run; both are
    # overwritten from the checkpoint when resuming.
    best_acc, start_epoch = 0, 0
    
    # 获取数据集,并先进行预处理
    # Fetch CIFAR-10 (downloaded into ./data if missing) and wrap it in loaders.
    print('==> Preparing data..')
    # Per-channel normalisation constants shared by the train and test pipelines.
    cifar_mean = (0.4914, 0.4822, 0.4465)
    cifar_std = (0.2023, 0.1994, 0.2010)

    # Training pipeline: augmentation (random 32x32 crop with 4-pixel padding,
    # random horizontal flip), then tensor conversion and normalisation.
    transform_train = transforms.Compose([
      transforms.RandomCrop(32, padding=4),
      transforms.RandomHorizontalFlip(),
      transforms.ToTensor(),
      transforms.Normalize(cifar_mean, cifar_std),
    ])

    # Evaluation pipeline: no augmentation.
    transform_test = transforms.Compose([
      transforms.ToTensor(),
      transforms.Normalize(cifar_mean, cifar_std),
    ])

    def _cifar10_loader(train, transform, batch_size):
      '''Return (dataset, loader) for one CIFAR-10 split; only the training split is shuffled.'''
      dataset = torchvision.datasets.CIFAR10(root='./data', train=train, download=True, transform=transform)
      loader = torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=train, num_workers=2)
      return dataset, loader

    trainset, trainloader = _cifar10_loader(True, transform_train, 128)
    testset, testloader = _cifar10_loader(False, transform_test, 100)

    # Human-readable names of the ten label indices, in label order.
    classes = ('plane', 'car', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck')
    
    # 继续训练模型或新建一个模型
    # Resume a checkpointed model or build a new one.
    if args.resume:
      # Load checkpoint.
      print('==> Resuming from checkpoint..')
      assert os.path.isdir('checkpoint'), 'Error: no checkpoint directory found!'
      # map_location='cpu' lets a checkpoint written on a GPU machine be loaded
      # on a CPU-only one; the model is moved to the GPU below when available.
      # NOTE(review): the checkpoint pickles the whole nn.Module; recent PyTorch
      # releases default torch.load to weights_only=True, which may reject it —
      # confirm against the installed version.
      checkpoint = torch.load('./checkpoint/ckpt.t7', map_location='cpu')
      net = checkpoint['net']
      best_acc = checkpoint['acc']
      start_epoch = checkpoint['epoch']
    else:
      print('==> Building model..')
      net = VGG('VGG16')
      # Other architectures can be swapped in here, presumably provided by the
      # `models` package imported above, e.g. ResNet18(), PreActResNet18(),
      # GoogLeNet(), DenseNet121(), ResNeXt29_2x64d(), MobileNet(),
      # MobileNetV2(), DPN92(), ShuffleNetG2(), SENet18().

    # Use the GPU(s) when available.
    if use_cuda:
      # Move parameters and buffers to the default GPU.
      net.cuda()
      # Replicate across every visible GPU. The previous
      # range(device_count() - 1) left one GPU idle and, on a single-GPU
      # machine, produced an empty device list that DataParallel cannot use.
      net = torch.nn.DataParallel(net, device_ids=list(range(torch.cuda.device_count())))
      # Let cuDNN benchmark convolution algorithms (input size is fixed).
      cudnn.benchmark = True


    # Loss and optimizer.
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.SGD(net.parameters(), lr=args.lr, momentum=0.9, weight_decay=5e-4)
    
    # 训练阶段
    # Training phase.
    def train(epoch):
      '''Train `net` for one epoch over `trainloader`, printing running loss/accuracy.

      Args:
        epoch: epoch index; only used in the printed header.
      '''
      print('\nEpoch: %d' % epoch)
      # Switch to train mode (BatchNorm uses per-batch statistics).
      net.train()
      train_loss = 0
      correct = 0
      total = 0
      for batch_idx, (inputs, targets) in enumerate(trainloader):
        if use_cuda:
          inputs, targets = inputs.cuda(), targets.cuda()
        # Gradients accumulate across backward() calls, so clear them every step.
        optimizer.zero_grad()
        outputs = net(inputs)
        loss = criterion(outputs, targets)
        # Back-propagate, then update the parameters.
        loss.backward()
        optimizer.step()
        # .item() yields a Python number detached from the graph. Summing `loss`
        # itself would keep every batch's graph alive, and the old
        # `loss.data[0]` idiom raises on the 0-dim loss tensor of PyTorch >= 0.4.
        train_loss += loss.item()
        _, predicted = outputs.max(1)
        total += targets.size(0)
        correct += predicted.eq(targets).sum().item()

        progress_bar(batch_idx, len(trainloader), 'Loss: %.3f | Acc: %.3f%% (%d/%d)'
          % (train_loss/(batch_idx+1), 100.*correct/total, correct, total))
    
    # 测试阶段
    # Evaluation phase.
    def test(epoch):
      '''Evaluate `net` on `testloader` and checkpoint it if accuracy improved.

      Prints running loss/accuracy. When the epoch's accuracy beats the global
      `best_acc`, saves {'net', 'acc', 'epoch'} to ./checkpoint/ckpt.t7 and
      updates `best_acc`.

      Args:
        epoch: epoch index stored in the checkpoint.
      '''
      global best_acc
      # Switch to eval mode (BatchNorm uses its running statistics).
      net.eval()
      test_loss = 0
      correct = 0
      total = 0
      # no_grad() stops autograd from building a graph during evaluation. The
      # legacy Variable(..., volatile=True) flag has had no effect since PyTorch 0.4.
      with torch.no_grad():
        for batch_idx, (inputs, targets) in enumerate(testloader):
          if use_cuda:
            inputs, targets = inputs.cuda(), targets.cuda()
          outputs = net(inputs)
          loss = criterion(outputs, targets)
          # .item() converts the 0-dim loss to a Python float (loss.data[0] raises).
          test_loss += loss.item()
          _, predicted = outputs.max(1)
          total += targets.size(0)
          correct += predicted.eq(targets).sum().item()

          progress_bar(batch_idx, len(testloader), 'Loss: %.3f | Acc: %.3f%% (%d/%d)'
            % (test_loss/(batch_idx+1), 100.*correct/total, correct, total))

      # Save a checkpoint when accuracy improves.
      acc = 100.*correct/total
      if acc > best_acc:
        print('Saving..')
        state = {
          # Store the unwrapped model so it can be re-wrapped on resume.
          'net': net.module if use_cuda else net,
          'acc': acc,
          'epoch': epoch,
        }
        if not os.path.isdir('checkpoint'):
          os.mkdir('checkpoint')
        torch.save(state, './checkpoint/ckpt.t7')
        best_acc = acc
    
    # 运行模型
    # Run 200 epochs, starting at start_epoch (non-zero when resuming). Each
    # epoch trains, evaluates/checkpoints, then asks PyTorch's CUDA allocator
    # to release cached, unused memory.
    stop_epoch = start_epoch + 200
    for ep in range(start_epoch, stop_epoch):
      train(ep)
      test(ep)
      torch.cuda.empty_cache()

运行：

从头训练新模型（学习率由 --lr 指定，默认值为 0.1）：
python main.py --lr=0.01
从检查点继续训练（需要之前的运行已保存 ./checkpoint/ckpt.t7）：
python main.py --resume --lr=0.01

一些工具函数（utils.py）：

    '''Some helper functions for PyTorch, including:
      - get_mean_and_std: calculate the mean and std value of dataset.
      - msr_init: net parameter initialization.
      - progress_bar: progress bar mimic xlua.progress.
    '''
    import os
    import sys
    import time
    import math
    
    import torch.nn as nn
    import torch.nn.init as init
    
    
    def get_mean_and_std(dataset, batch_size=1, num_workers=2):
      '''Compute the per-channel mean and standard deviation of `dataset`.

      Statistics are taken over every pixel of every image, i.e. the values
      suitable for transforms.Normalize. The previous version averaged
      per-image standard deviations, which is not the dataset standard
      deviation because it ignores the spread of the per-image means.

      Args:
        dataset: yields (image, target) pairs; images are assumed to be
          (C, H, W) tensors — TODO confirm for custom datasets.
        batch_size: DataLoader batch size (the result does not depend on it).
        num_workers: DataLoader worker processes.

      Returns:
        (mean, std): float32 tensors of shape (C,); std is the population std.

      Raises:
        ValueError: if the dataset yields no pixels.
      '''
      # This utils module only imports torch.nn / torch.nn.init, which do not
      # bind the name `torch`, so import it here.
      import torch
      import torch.utils.data
      dataloader = torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=False, num_workers=num_workers)
      channel_sum = 0.
      channel_sq_sum = 0.
      count = 0
      print('==> Computing mean and std..')
      for inputs, _targets in dataloader:
        # Reshape (N, C, H, W) -> (C, N*H*W) and accumulate in float64 so that
        # E[x^2] - E[x]^2 does not lose precision.
        flat = inputs.double().transpose(0, 1).contiguous().view(inputs.size(1), -1)
        channel_sum = channel_sum + flat.sum(1)
        channel_sq_sum = channel_sq_sum + (flat * flat).sum(1)
        count += flat.size(1)
      if count == 0:
        raise ValueError('get_mean_and_std: dataset is empty')
      mean = channel_sum / count
      # Clamp tiny negative variances caused by rounding.
      var = (channel_sq_sum / count - mean * mean).clamp(min=0)
      return mean.float(), var.sqrt().float()
    
    def init_params(net):
      '''Initialise the parameters of every Conv2d, BatchNorm2d and Linear in `net`.

      - Conv2d: Kaiming-normal weights (mode='fan_out'), bias 0.
      - BatchNorm2d: weight 1, bias 0 (skipped when affine=False).
      - Linear: normal weights with std 1e-3, bias 0.
      '''
      for m in net.modules():
        if isinstance(m, nn.Conv2d):
          init.kaiming_normal_(m.weight, mode='fan_out')
          # `if m.bias:` raised for multi-element bias tensors; check presence instead.
          if m.bias is not None:
            init.constant_(m.bias, 0)
        elif isinstance(m, nn.BatchNorm2d):
          # BatchNorm2d(affine=False) has no weight/bias.
          if m.weight is not None:
            init.constant_(m.weight, 1)
          if m.bias is not None:
            init.constant_(m.bias, 0)
        elif isinstance(m, nn.Linear):
          init.normal_(m.weight, std=1e-3)
          if m.bias is not None:
            init.constant_(m.bias, 0)
    
    
    def _terminal_width(default=80):
      '''Return the column count reported by `stty size`, or `default` if unavailable.

      `stty size` prints "rows cols". When it cannot query a terminal (e.g.
      stdin is not a TTY: pipes, IDEs, notebooks, CI) it prints nothing on
      stdout, which made the old unconditional tuple unpacking raise at import.
      '''
      try:
        pipe = os.popen('stty size', 'r')
        try:
          output = pipe.read()
        finally:
          pipe.close()
        _rows, cols = output.split()
        width = int(cols)
      except (OSError, ValueError):
        return default
      # Some pseudo-terminals report a size of "0 0".
      return width if width > 0 else default


    term_width = _terminal_width()

    # Width of the bar in characters, and timestamps used by progress_bar.
    TOTAL_BAR_LENGTH = 65.
    last_time = time.time()
    begin_time = last_time
    def progress_bar(current, total, msg=None):
      '''Draw a single-line text progress bar for step `current` (0-based) of `total`.

      The line holds the bar, the time since the previous call, the time since
      step 0, an optional `msg`, and a "current+1/total" counter. It ends in
      '\\r' so the next call overwrites it, and in '\\n' on the last step.
      '''
      global last_time, begin_time
      if current == 0:
        begin_time = time.time()  # A new bar starts: reset the total timer.

      done = int(TOTAL_BAR_LENGTH*current/total)
      remaining = int(TOTAL_BAR_LENGTH - done) - 1
      # Negative repeat counts yield '' (matching an empty loop).
      sys.stdout.write(' [' + '=' * done + '>' + '.' * remaining + ']')

      now = time.time()
      step_time = now - last_time
      last_time = now
      tot_time = now - begin_time

      info = ' Step: %s' % format_time(step_time) + ' | Tot: %s' % format_time(tot_time)
      if msg:
        info += ' | ' + msg
      sys.stdout.write(info)
      # Pad the rest of the line with spaces.
      sys.stdout.write(' ' * (term_width - int(TOTAL_BAR_LENGTH) - len(info) - 3))

      # Back the cursor up toward the middle of the bar and print the step counter there.
      sys.stdout.write('\b' * (term_width - int(TOTAL_BAR_LENGTH/2) + 2))
      sys.stdout.write(' %d/%d ' % (current+1, total))

      sys.stdout.write('\r' if current < total-1 else '\n')
      sys.stdout.flush()
    
    def format_time(seconds):
      '''Render a duration in seconds using its first two non-zero units, largest first.

      Units are days ('D'), hours ('h'), minutes ('m'), seconds ('s') and
      milliseconds ('ms'), e.g. 3661 -> '1h1m', 1.5 -> '1s500ms'.
      If no unit is positive the result is '0ms'.
      '''
      remaining = seconds
      days = int(remaining / 3600/24)
      remaining = remaining - days*3600*24
      hours = int(remaining / 3600)
      remaining = remaining - hours*3600
      minutes = int(remaining / 60)
      remaining = remaining - minutes*60
      whole_seconds = int(remaining)
      millis = int((remaining - whole_seconds)*1000)

      units = ((days, 'D'), (hours, 'h'), (minutes, 'm'),
               (whole_seconds, 's'), (millis, 'ms'))
      shown = [str(amount) + suffix for amount, suffix in units if amount > 0][:2]
      return ''.join(shown) or '0ms'

以上就是本文的全部内容,希望对大家的学习有所帮助,也希望大家多多支持脚本之家。

在这里插入图片描述

posted @ 2021-06-16 15:28  老酱  阅读(522)  评论(0编辑  收藏  举报