CBAM Attention Model for Image Classification (PyTorch)

For a current project I needed to combine a detection model with image classification, so I reorganized the CBAM code into just two files, train.py and test.py, which handle training and image classification respectively. To make the code easier to study, this post is split into two parts: the first briefly introduces the principle (drawing on another blog post), and the second walks through the rewritten code, explaining how to train your own model and test images.

Paper: CBAM: Convolutional Block Attention Module

Code: https://github.com/tangjunjun966/CBAM_PyTorch

1. Basic principle

The Convolutional Block Attention Module (CBAM) is an attention module for convolutional networks that combines spatial and channel attention. Given a feature map F, CBAM applies channel attention and then spatial attention sequentially: F' = M_c(F) ⊗ F, followed by F'' = M_s(F') ⊗ F'. Compared with SENet, which only attends over channels, this combination typically gives better results.

For a plain VGG-style backbone, the CBAM module is inserted after each convolutional layer.

For backbones with shortcut connections, e.g. ResNet-50, the module is inserted at the end of each residual block, before the result is added back to the shortcut; the sketch below shows this placement.
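As a minimal sketch (the full implementation is the Bottleneck_CBAM class in the training code of section 2; the block and attribute names here, such as convs, are illustrative only), the attention re-weighting happens right before the residual addition:

import torch.nn as nn

class CBAMBlockSketch(nn.Module):
    # Sketch of where CBAM sits inside a residual block (see Bottleneck_CBAM below).
    def __init__(self, channels):
        super().__init__()
        self.convs = nn.Sequential(                       # stand-in for the block's conv -> bn -> relu stack
            nn.Conv2d(channels, channels, 3, padding=1, bias=False),
            nn.BatchNorm2d(channels),
            nn.ReLU(inplace=True),
        )
        self.ca = ChannelAttention(channels)              # channel attention module defined in the training code below
        self.sa = SpatialAttention()                      # spatial attention module defined in the training code below
        self.relu = nn.ReLU(inplace=True)

    def forward(self, x):
        out = self.convs(x)
        out = self.ca(out) * out      # re-weight channels
        out = self.sa(out) * out      # re-weight spatial positions
        return self.relu(out + x)     # attention is applied before the residual addition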


Channel attention module:


The input feature map is passed through global max pooling and global average pooling over the spatial dimensions (H×W), and each pooled descriptor is fed through a shared MLP. The two MLP outputs are summed element-wise and passed through a sigmoid to obtain the channel attention map. Multiplying this map element-wise with the input feature map produces the feature that is fed to the spatial attention module:

M_c(F) = σ( MLP(AvgPool(F)) + MLP(MaxPool(F)) ) = σ( W_1(W_0(F_avg^c)) + W_1(W_0(F_max^c)) )

where σ is the sigmoid function, r is the reduction ratio of the shared MLP (W_0 ∈ R^{C/r×C}, W_1 ∈ R^{C×C/r}), and a ReLU follows W_0.
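A minimal PyTorch sketch of this computation (the shared MLP is realized with two 1×1 convolutions, the same structure as the ChannelAttention class used in the training code below; the class name and default reduction here are illustrative):

import torch
import torch.nn as nn

class ChannelAttentionSketch(nn.Module):
    # Shared MLP applied to the average- and max-pooled channel descriptors.
    def __init__(self, channels, reduction=16):
        super().__init__()
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.max_pool = nn.AdaptiveMaxPool2d(1)
        self.mlp = nn.Sequential(                                    # W_0 (+ ReLU) then W_1, as 1x1 convolutions
            nn.Conv2d(channels, channels // reduction, 1, bias=False),
            nn.ReLU(),
            nn.Conv2d(channels // reduction, channels, 1, bias=False),
        )

    def forward(self, x):  # x: (N, C, H, W) -> attention map of shape (N, C, 1, 1)
        return torch.sigmoid(self.mlp(self.avg_pool(x)) + self.mlp(self.max_pool(x)))

# usage: feat = ChannelAttentionSketch(256)(feat) * feat   (element-wise re-weighting)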

 

Spatial attention module:


The feature map output by the channel attention module is the input here. First, channel-wise global max pooling and global average pooling are applied, and the two resulting single-channel maps are concatenated along the channel dimension. A convolution then reduces them to a single channel, and a sigmoid produces the spatial attention map. Finally, this map is multiplied with the module's input feature map to give the output features:

M_s(F) = σ( f^{7×7}( [AvgPool(F); MaxPool(F)] ) )

where σ is the sigmoid function and f^{7×7} denotes a convolution with a 7×7 kernel; the paper reports that a 7×7 kernel works better than 3×3.
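A minimal sketch of the spatial attention computation, matching the SpatialAttention class in the training code below (the class name here is illustrative):

import torch
import torch.nn as nn

class SpatialAttentionSketch(nn.Module):
    # 7x7 convolution over the concatenation of channel-wise average and max maps.
    def __init__(self, kernel_size=7):
        super().__init__()
        self.conv = nn.Conv2d(2, 1, kernel_size, padding=kernel_size // 2, bias=False)

    def forward(self, x):  # x: (N, C, H, W) -> attention map of shape (N, 1, H, W)
        avg_map = torch.mean(x, dim=1, keepdim=True)
        max_map, _ = torch.max(x, dim=1, keepdim=True)
        return torch.sigmoid(self.conv(torch.cat([avg_map, max_map], dim=1)))

# usage: feat = SpatialAttentionSketch()(feat) * feat   (element-wise re-weighting)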

2. Using the code

Copy the code into a project folder organized as follows (an example layout is sketched below):
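Since torchvision's ImageFolder is used (see build_dataset below), datasets/ needs train/ and val/ folders, each containing one sub-folder per class; the class folder names here are only examples:

datasets/
├── train/
│   ├── class_0/        # e.g. normal images
│   │   ├── xxx.jpg
│   │   └── ...
│   └── class_1/        # e.g. defect images
│       └── ...
├── val/
│   ├── class_0/
│   └── class_1/
└── work_dir/           # checkpoints are written here during training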


The training code has been organized into dataset construction, model construction, and so on. Copy it, adjust the parameters inside args, and it can be run directly.

The training code is as follows:


from collections import OrderedDict
import argparse
import torch.optim as optim
from torch.optim import lr_scheduler
from torchvision import transforms, models, datasets
from torchnet.meter import ClassErrorMeter, ConfusionMeter
import torch.backends.cudnn as cudnn
import torch.nn.functional as F
import traceback
import os
import time
import torch
import torch.nn as nn
import math
import torch.utils.model_zoo as model_zoo
import sys
from PIL import Image
import numpy as np

def load_state_dict(model_dir, is_multi_gpu):
    state_dict = torch.load(model_dir, map_location=lambda storage, loc: storage)['state_dict']
    if is_multi_gpu:
        new_state_dict = OrderedDict()
        for k, v in state_dict.items():
            name = k[7:]  # remove `module.`
            new_state_dict[name] = v
        return new_state_dict
    else:
        return state_dict





def parse_parameters():
    parser = argparse.ArgumentParser(description='PyTorch Template')
    parser.add_argument('--resume', default='', type=str, help='path to latest checkpoint (default: None)')  # rarely used
    parser.add_argument('--debug', action='store_true', dest='debug', help='trainer debug flag')  # not used
    parser.add_argument('--gpu', default='0', type=str, help='GPU ID Select')  # for multiple GPUs use e.g. '0,1,2'
    parser.add_argument('--data_root', default='./datasets/', type=str, help='data root')  # datasets/ contains train/ and val/ folders, each holding one sub-folder per class (defect images); see the code for the exact paths
    parser.add_argument('--train_file', default='./datasets/train.txt', type=str, help='train file')
    parser.add_argument('--val_file', default='./datasets/val.txt',
                        type=str, help='validation file')
    parser.add_argument('--model', default='resnet50_cbam', type=str, help='model type')
    parser.add_argument('--batch_size', default=4, type=int, help='model train batch size')
    parser.add_argument('--display', action='store_true', dest='display', help='Use TensorboardX to Display')
    parser.add_argument('--classes', default=2, type=int, help='Number of classes')
    parser.add_argument('--work_dir', default='./datasets/work_dir', type=str, help='work directory')
    parser.add_argument('--total_epochs', default=36, type=int, help='total epoch')


    args = parser.parse_args()
    return args


class Logger(object):
    '''Save training process to log file with simple plot function.'''
    def __init__(self, fpath, resume=False):
        self.file = None
        self.resume = resume
        if resume and os.path.isfile(fpath):
            self.file = open(fpath, 'a')  # append to the existing log when resuming
        else:
            self.file = open(fpath, 'w')

    def append(self, target_str):
        if not isinstance(target_str, str):
            try:
                target_str = str(target_str)
            except:
                traceback.print_exc()
            else:
                print(target_str)
                self.file.write(target_str + '\n')
                self.file.flush()
        else:
            print(target_str)
            self.file.write(target_str + '\n')
            self.file.flush()

    def close(self):
        if self.file is not None:
            self.file.close()




class Concat_patch(object):
    """Crop the four corner patches of the input PIL Image and concatenate
    them back into a single image.

    Args:
        margin_ratio (tuple): Fractions of the image height and width used as
            the corner patch size, e.g. (0.25, 0.25) takes a quarter of each
            side from every corner.
    """

    def __init__(self, margin_ratio=(0.25, 0.25)):
        self.margin_ratio = margin_ratio

    def __call__(self, img):
        """
        Args:
            img (PIL Image): Image to be cropped and re-concatenated.

        Returns:
            PIL Image: Image built from the four corner patches.
        """
        array_img = np.array(img)
        h, w, c = array_img.shape
        h_margin = int(h * self.margin_ratio[0])
        w_margin = int(w * self.margin_ratio[1])
        patches = [array_img[0:h_margin, 0:w_margin, :], array_img[h - h_margin:, 0:w_margin, :],
                   array_img[0:h_margin, w - w_margin:, :], array_img[h - h_margin:, w - w_margin:, :]]

        def concat_patches(patches):
            a = np.concatenate(patches[:2], axis=0)
            b = np.concatenate(patches[2:], axis=0)
            c = np.concatenate([a, b], axis=1)
            return c

        img = concat_patches(patches)
        img = Image.fromarray(img)
        return img

    def __repr__(self):
        return self.__class__.__name__ + '(margin_ratio={0})'.format(self.margin_ratio)


def build_dataset(args):
    gpus = args.gpu.split(',')
    data_transforms = {
        'train': transforms.Compose([
            Concat_patch(),
            transforms.Resize((224, 224)),
            # transforms.Resize((320, 320)),
            transforms.RandomHorizontalFlip(),
            transforms.RandomVerticalFlip(),
            # transforms.RandomRotation(90),
            transforms.ToTensor(),
            transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
        ]),
        'val': transforms.Compose([
            Concat_patch(),
            transforms.Resize((224, 224)),
            transforms.ToTensor(),
            transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
        ])
    }
    train_datasets = datasets.ImageFolder(os.path.join(args.data_root, 'train'), data_transforms['train'])
    val_datasets = datasets.ImageFolder(os.path.join(args.data_root, 'val'), data_transforms['val'])
    # sampler = torch.utils.data.WeightedRandomSampler(weights=[1, 1], num_samples=len(train_datasets), replacement=True)
    train_dataloaders = torch.utils.data.DataLoader(train_datasets, batch_size=args.batch_size * len(gpus),
                                                    shuffle=True, num_workers=4)
    val_dataloaders = torch.utils.data.DataLoader(val_datasets, batch_size=4, shuffle=False, num_workers=4)

    return train_dataloaders,val_dataloaders

def build_model(args):
    if 'resnet50' == args.model:
        my_model = resnet50(pretrained=False, num_classes=args.classes)
    elif 'resnet50_cbam' == args.model:
        my_model = resnet50_cbam(pretrained=False, num_classes=args.classes)
    elif 'resnet101' == args.model:
        my_model = models.resnet101(pretrained=False, num_classes=args.classes)
    elif 'resnet18' == args.model:
        my_model = models.resnet18(pretrained=False, num_classes=args.classes)
    elif 'resnet18_cbam' == args.model:
        my_model = resnet18_cbam(pretrained=True, num_classes=args.classes)
    else:
        raise ValueError('Unsupported model type: %s' % args.model)


    return my_model



def build_optimizer(model):
    loss_fn = [nn.CrossEntropyLoss(weight=torch.Tensor([0.5, 5]).cuda())]  # kept on cuda:0 because the model outputs are processed there
    # loss_fn = [nn.CrossEntropyLoss()]
    optimizer = optim.SGD(model.parameters(), lr=0.02, momentum=0.9, weight_decay=1e-4)
    lr_schedule = lr_scheduler.MultiStepLR(optimizer, milestones=[16, 24, 32], gamma=0.1)  # lr is decayed per epoch
    return loss_fn, optimizer, lr_schedule



class Trainer():
    def __init__(self, model, model_type, loss_fn, optimizer, lr_schedule, log_batchs, is_use_cuda, train_data_loader, \
                 valid_data_loader=None, metric=None, start_epoch=0, num_epochs=25, is_debug=False, logger=None,
                  workdir='.'):
        self.model = model
        self.model_type = model_type
        self.loss_fn = loss_fn
        self.optimizer = optimizer
        self.lr_schedule = lr_schedule
        self.log_batchs = log_batchs
        self.is_use_cuda = is_use_cuda
        self.train_data_loader = train_data_loader
        self.valid_data_loader = valid_data_loader
        self.metric = metric
        self.start_epoch = start_epoch
        self.num_epochs = num_epochs
        self.is_debug = is_debug

        self.cur_epoch = start_epoch
        self.best_acc = 0.
        self.best_loss = sys.float_info.max
        self.logger = logger

        self.workdir = workdir

    def fit(self):
        for epoch in range(0, self.start_epoch):
            self.lr_schedule.step()

        for epoch in range(self.start_epoch, self.num_epochs):
            self.logger.append('Epoch {}/{}'.format(epoch, self.num_epochs - 1))
            self.logger.append('-' * 60)
            self.cur_epoch = epoch
            self.lr_schedule.step()  # advances scheduler.last_epoch; the lr changes when a milestone is reached
            if self.is_debug:
                self._dump_infos()
            self._train()
            self._valid()
            self._save_best_model()
            print()

    def _dump_infos(self):
        self.logger.append('---------------------Current Parameters---------------------')
        self.logger.append('is use GPU: ' + ('True' if self.is_use_cuda else 'False'))
        self.logger.append('lr: %f' % (self.lr_schedule.get_lr()[0]))
        self.logger.append('model_type: %s' % (self.model_type))
        self.logger.append('current epoch: %d' % (self.cur_epoch))
        self.logger.append('best accuracy: %f' % (self.best_acc))
        self.logger.append('best loss: %f' % (self.best_loss))
        self.logger.append('------------------------------------------------------------')

    def _train(self):
        self.model.train()  # Set model to training mode
        losses = []
        if self.metric is not None:
            self.metric[0].reset()
            # self.metric[1].reset()

        for i, (inputs, labels) in enumerate(self.train_data_loader):  # Notice
            if self.is_use_cuda:
                inputs, labels = inputs.cuda(), labels.cuda()
                labels = labels.squeeze()
            else:
                labels = labels.squeeze()

            self.optimizer.zero_grad()  # clear accumulated gradients
            outputs = self.model(inputs)  # Notice
            loss = self.loss_fn[0](outputs, labels)
            if self.metric is not None:
                prob = F.softmax(outputs, dim=1).data.cpu()
                self.metric[0].add(prob, labels.data.cpu())

                #one_hot = torch.zeros(prob.shape[0], prob.shape[1]).scatter_(1, labels.cpu(), 1)
                # self.metric[1].add(prob, labels.data.cpu())
            loss.backward()
            self.optimizer.step()

            losses.append(loss.item())  # Notice
            if 0 == i % self.log_batchs or (i == len(self.train_data_loader) - 1):
                local_time_str = time.strftime('%Y-%m-%d %H:%M:%S', time.localtime(time.time()))
                batch_mean_loss = np.mean(losses)
                print_str = '[%s]\tTraining Batch[%d/%d]\t Class Loss: %.4f\t' \
                            % (local_time_str, i, len(self.train_data_loader) - 1, batch_mean_loss)
                if i == len(self.train_data_loader) - 1 and self.metric is not None:
                    confusion = self.metric[0].value()
                    print(confusion)
                    # top1_acc_score = self.metric[0].value()[0]
                    # top3_acc_score = self.metric[0].value()[1]
                    # print_str += '@Top-1 Score: %.4f\t' % (top1_acc_score)
                    # print_str += '@Top-3 Score: %.4f\t' % (top3_acc_score)
                    # print(self.metric[1].value())
                self.logger.append(print_str)



    def _valid(self):
        self.model.eval()
        losses = []
        acc_rate = 0.
        if self.metric is not None:
            self.metric[0].reset()

        with torch.no_grad():  # Notice
            for i, (inputs, labels) in enumerate(self.valid_data_loader):
                if self.is_use_cuda:
                    inputs, labels = inputs.cuda(), labels.cuda()
                    labels = labels.squeeze()
                else:
                    labels = labels.squeeze()

                if len(labels.shape) == 0:
                    labels = labels.view(-1)
                outputs = self.model(inputs)  # Notice
                loss = self.loss_fn[0](outputs, labels)

                if self.metric is not None:
                    prob = F.softmax(outputs, dim=1).data.cpu()
                    self.metric[0].add(prob, labels.data.cpu())
                losses.append(loss.item())

        local_time_str = time.strftime('%Y-%m-%d %H:%M:%S', time.localtime(time.time()))
        # self.logger.append(losses)
        batch_mean_loss = np.mean(losses)
        print_str = '[%s]\tValidation: \t Class Loss: %.4f\t' \
                    % (local_time_str, batch_mean_loss)
        if self.metric is not None:
            confusion = self.metric[0].value()
            print(confusion)
            # top1_acc_score = self.metric[0].value()[0]
            # top3_acc_score = self.metric[0].value()[1]
            # print_str += '@Top-1 Score: %.4f\t' % (top1_acc_score)
            # print_str += '@Top-3 Score: %.4f\t' % (top3_acc_score)
        self.logger.append(print_str)
        # if top1_acc_score >= self.best_acc:
        #     self.best_acc = top1_acc_score
        #     self.best_loss = batch_mean_loss

    def _save_best_model(self):
        # Save Model
        self.logger.append('Saving Model...')
        state = {
            'state_dict': self.model.state_dict(),
            'best_acc': self.best_acc,
            'cur_epoch': self.cur_epoch,
            'num_epochs': self.num_epochs
        }
        if not os.path.isdir(os.path.join(self.workdir, 'checkpoint/') + self.model_type):
            os.makedirs(os.path.join(self.workdir, 'checkpoint/') + self.model_type)
        torch.save(state,
                   os.path.join(self.workdir, 'checkpoint/') + self.model_type + '/Models' + '_epoch_%d' % self.cur_epoch + '.ckpt')  # Notice


# Network definitions





model_urls = {
    'resnet18': 'https://download.pytorch.org/models/resnet18-5c106cde.pth',
    'resnet34': 'https://download.pytorch.org/models/resnet34-333f7ec4.pth',
    'resnet50': 'https://download.pytorch.org/models/resnet50-19c8e357.pth',
    'resnet101': 'https://download.pytorch.org/models/resnet101-5d3b4d8f.pth',
    'resnet152': 'https://download.pytorch.org/models/resnet152-b121ed2d.pth',
}
def conv3x3(in_planes, out_planes, stride=1):
    "3x3 convolution with padding"
    return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
                     padding=1, bias=False)
class ChannelAttention(nn.Module):
    def __init__(self, in_planes, ratio=16):
        super(ChannelAttention, self).__init__()
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.max_pool = nn.AdaptiveMaxPool2d(1)

        self.fc1 = nn.Conv2d(in_planes, in_planes // ratio, 1, bias=False)
        self.relu1 = nn.ReLU()
        self.fc2 = nn.Conv2d(in_planes // ratio, in_planes, 1, bias=False)

        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        avg_out = self.fc2(self.relu1(self.fc1(self.avg_pool(x))))
        max_out = self.fc2(self.relu1(self.fc1(self.max_pool(x))))
        out = avg_out + max_out
        return self.sigmoid(out)

class SpatialAttention(nn.Module):
    def __init__(self, kernel_size=7):
        super(SpatialAttention, self).__init__()

        assert kernel_size in (3, 7), 'kernel size must be 3 or 7'
        padding = 3 if kernel_size == 7 else 1

        self.conv1 = nn.Conv2d(2, 1, kernel_size, padding=padding, bias=False)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        avg_out = torch.mean(x, dim=1, keepdim=True)
        max_out, _ = torch.max(x, dim=1, keepdim=True)  # torch.max returns (values, indices); keepdim keeps the channel dim
        x = torch.cat([avg_out, max_out], dim=1)
        x = self.conv1(x)
        return self.sigmoid(x)

class BasicBlock(nn.Module):
    expansion = 1

    def __init__(self, inplanes, planes, stride=1, downsample=None):
        super(BasicBlock, self).__init__()
        self.conv1 = conv3x3(inplanes, planes, stride)
        self.bn1 = nn.BatchNorm2d(planes)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = conv3x3(planes, planes)
        self.bn2 = nn.BatchNorm2d(planes)

        self.ca = ChannelAttention(planes)
        self.sa = SpatialAttention()

        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        residual = x

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)

        out = self.ca(out) * out
        out = self.sa(out) * out

        if self.downsample is not None:
            residual = self.downsample(x)

        out += residual
        out = self.relu(out)

        return out

class Bottleneck_CBAM(nn.Module):
    expansion = 4

    def __init__(self, inplanes, planes, stride=1, downsample=None):
        super(Bottleneck_CBAM, self).__init__()
        self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False)
        self.bn1 = nn.BatchNorm2d(planes)
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride,
                               padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(planes)
        self.conv3 = nn.Conv2d(planes, planes * 4, kernel_size=1, bias=False)
        self.bn3 = nn.BatchNorm2d(planes * 4)
        self.relu = nn.ReLU(inplace=True)

        self.ca = ChannelAttention(planes * 4)
        self.sa = SpatialAttention()

        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        residual = x

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)
        out = self.relu(out)

        out = self.conv3(out)
        out = self.bn3(out)

        out = self.ca(out) * out
        out = self.sa(out) * out

        if self.downsample is not None:
            residual = self.downsample(x)

        out += residual
        out = self.relu(out)

        return out

class Bottleneck(nn.Module):
    expansion = 4

    def __init__(self, inplanes, planes, stride=1, downsample=None):
        super(Bottleneck, self).__init__()
        self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False)
        self.bn1 = nn.BatchNorm2d(planes)
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride,
                               padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(planes)
        self.conv3 = nn.Conv2d(planes, planes * 4, kernel_size=1, bias=False)
        self.bn3 = nn.BatchNorm2d(planes * 4)
        self.relu = nn.ReLU(inplace=True)

        # self.ca = ChannelAttention(planes * 4)
        # self.sa = SpatialAttention()

        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        residual = x

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)
        out = self.relu(out)

        out = self.conv3(out)
        out = self.bn3(out)

        # out = self.ca(out) * out
        # out = self.sa(out) * out

        if self.downsample is not None:
            residual = self.downsample(x)

        out += residual
        out = self.relu(out)

        return out
class ResNet(nn.Module):

    def __init__(self, block, layers, num_classes=23):
        self.inplanes = 64
        super(ResNet, self).__init__()
        self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3,
                               bias=False)
        self.bn1 = nn.BatchNorm2d(64)
        self.relu = nn.ReLU(inplace=True)

        self.layer1 = self._make_layer(block, 64, layers[0], stride=2)
        self.layer2 = self._make_layer(block, 128, layers[1], stride=2)
        self.layer3 = self._make_layer(block, 256, layers[2], stride=2)
        self.layer4 = self._make_layer(block, 512, layers[3], stride=2)

        self.avgpool = nn.AvgPool2d(7, stride=1)
        self.fc = nn.Linear(512 * block.expansion, num_classes)

        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
                m.weight.data.normal_(0, math.sqrt(2. / n))
            elif isinstance(m, nn.BatchNorm2d):
                m.weight.data.fill_(1)
                m.bias.data.zero_()

    def _make_layer(self, block, planes, blocks, stride=1):
        downsample = None
        if stride != 1 or self.inplanes != planes * block.expansion:
            downsample = nn.Sequential(
                nn.Conv2d(self.inplanes, planes * block.expansion,
                          kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(planes * block.expansion),
            )

        layers = []
        layers.append(block(self.inplanes, planes, stride, downsample))
        self.inplanes = planes * block.expansion
        for i in range(1, blocks):
            layers.append(block(self.inplanes, planes))

        return nn.Sequential(*layers)

    def forward(self, x):
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)

        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)

        x = self.avgpool(x)
        x = x.view(x.size(0), -1)
        x = self.fc(x)

        return x
def resnet18_cbam(pretrained=False, **kwargs):
    """Constructs a ResNet-18 model.

    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
    """
    model = ResNet(BasicBlock, [2, 2, 2, 2], **kwargs)
    if pretrained:
        pretrained_state_dict = model_zoo.load_url(model_urls['resnet18'])
        now_state_dict = model.state_dict()
        now_state_dict.update(pretrained_state_dict)
        now_state_dict.pop('fc.weight')
        now_state_dict.pop('fc.bias')
        model.load_state_dict(now_state_dict, strict=False)
    return model

def resnet34_cbam(pretrained=False, **kwargs):
    """Constructs a ResNet-34 model.

    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
    """
    model = ResNet(BasicBlock, [3, 4, 6, 3], **kwargs)
    if pretrained:
        pretrained_state_dict = model_zoo.load_url(model_urls['resnet34'])
        now_state_dict = model.state_dict()
        now_state_dict.update(pretrained_state_dict)
        model.load_state_dict(now_state_dict)
    return model

def resnet50_cbam(pretrained=False, **kwargs):
    """Constructs a ResNet-50 model.

    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
    """
    model = ResNet(Bottleneck_CBAM, [3, 4, 6, 3], **kwargs)
    if pretrained:
        pretrained_state_dict = model_zoo.load_url(model_urls['resnet50'])
        now_state_dict = model.state_dict()
        now_state_dict.update(pretrained_state_dict)
        # drop the ImageNet fc weights so that num_classes can differ from 1000
        now_state_dict.pop('fc.weight')
        now_state_dict.pop('fc.bias')
        model.load_state_dict(now_state_dict, strict=False)
    return model

def resnet50(pretrained=False, **kwargs):
    """Constructs a ResNet-50 model.

    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
    """
    model = ResNet(Bottleneck, [3, 4, 6, 3], **kwargs)
    if pretrained:
        pretrained_state_dict = model_zoo.load_url(model_urls['resnet50'])
        now_state_dict = model.state_dict()
        now_state_dict.update(pretrained_state_dict)
        model.load_state_dict(now_state_dict)
    return model

def resnet101_cbam(pretrained=False, **kwargs):
    """Constructs a ResNet-101 model.

    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
    """
    model = ResNet(Bottleneck_CBAM, [3, 4, 23, 3], **kwargs)
    if pretrained:
        pretrained_state_dict = model_zoo.load_url(model_urls['resnet101'])
        now_state_dict = model.state_dict()
        now_state_dict.update(pretrained_state_dict)
        model.load_state_dict(now_state_dict)
    return model

def resnet152_cbam(pretrained=False, **kwargs):
    """Constructs a ResNet-152 model.

    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
    """
    model = ResNet(Bottleneck_CBAM, [3, 8, 36, 3], **kwargs)
    if pretrained:
        pretrained_state_dict = model_zoo.load_url(model_urls['resnet152'])
        now_state_dict = model.state_dict()
        now_state_dict.update(pretrained_state_dict)
        model.load_state_dict(now_state_dict)
    return model


def train():
    args=parse_parameters()
    logger = Logger('./' + args.model + '.log') if len(args.resume)==0 else Logger('./' + args.model + '.log', True)
    logger.append(vars(args))

    os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu
    is_use_cuda = torch.cuda.is_available()
    cudnn.benchmark = True

    train_dataloaders, val_dataloaders = build_dataset(args)
    model=build_model(args)
    loss_fn, optimizer, lr_schedule = build_optimizer(model)


    if is_use_cuda and 1 == len(args.gpu.split(',')):
        model = model.cuda()
    elif is_use_cuda and 1 < len(args.gpu.split(',')):
        model = nn.DataParallel(model.cuda())  # model.cuda() keeps the parameters cached on cuda:0


    metric = [ConfusionMeter(2)]
    start_epoch = 0
    my_trainer = Trainer(model, args.model, loss_fn, optimizer, lr_schedule, 10, is_use_cuda, train_dataloaders,
                         val_dataloaders, metric, start_epoch, args.total_epochs, args.debug, logger,  args.work_dir)
    my_trainer.fit()
    logger.append('Optimize Done!')









if __name__ == '__main__':
    train()
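With the parameters above, training can be launched from the command line. Assuming the training code is saved as train.py (the test code below imports it as train_new, so keep the file name and the import consistent), for example:

python train.py --gpu 0 --data_root ./datasets/ --model resnet50_cbam --batch_size 4 --classes 2 --total_epochs 36 --work_dir ./datasets/work_dir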

 



The test code imports the model definition from the training code, so both scripts must sit in the same directory (the import below assumes the training script is named train_new.py; adjust it to match your file name).
The test code is as follows:


from collections import OrderedDict
from PIL import Image
import torch
import torch.nn.functional as F
from torchvision import transforms
import numpy as np

from train_new import resnet50_cbam

def init_cls_model(checkpoint_path, is_multi_gpu=False, classes=2):

    my_model = resnet50_cbam(num_classes=classes)
    state_dict = torch.load(checkpoint_path)['state_dict']
    if is_multi_gpu:
        new_state_dict = OrderedDict()
        for k, v in state_dict.items():
            name = k[7:]  # remove `module.`
            new_state_dict[name] = v
        my_model.load_state_dict(new_state_dict)
    else:
        my_model.load_state_dict(state_dict)

    my_model = my_model.cuda()
    my_model.eval()

    return my_model

class Concat_patch(object):  # corner-patch crop; optional in practice
    """Crop the four corner patches of the input PIL Image and concatenate
    them back into a single image.

    Args:
        margin_ratio (tuple): Fractions of the image height and width used as
            the corner patch size, e.g. (0.25, 0.25) takes a quarter of each
            side from every corner.
    """

    def __init__(self, margin_ratio=(0.25, 0.25)):
        self.margin_ratio = margin_ratio

    def __call__(self, img):
        """
        Args:
            img (PIL Image): Image to be cropped and re-concatenated.

        Returns:
            PIL Image: Image built from the four corner patches.
        """
        array_img = np.array(img)
        h, w, c = array_img.shape
        h_margin = int(h * self.margin_ratio[0])
        w_margin = int(w * self.margin_ratio[1])
        patches = [array_img[0:h_margin, 0:w_margin, :], array_img[h - h_margin:, 0:w_margin, :],
                   array_img[0:h_margin, w - w_margin:, :], array_img[h - h_margin:, w - w_margin:, :]]

        def concat_patches(patches):
            a = np.concatenate(patches[:2], axis=0)
            b = np.concatenate(patches[2:], axis=0)
            c = np.concatenate([a, b], axis=1)
            return c

        img = concat_patches(patches)
        img = Image.fromarray(img)
        return img

    def __repr__(self):
        return self.__class__.__name__ + '(margin_ratio={0})'.format(self.margin_ratio)

def cls_judge(img_path, model, img_size=224):
    FALSE_NAME = 'FALSE'
    NG_NAME = 'NG'

    CLS_NAME = [FALSE_NAME, NG_NAME]
    data_transform = transforms.Compose([
        Concat_patch(),
        transforms.Resize((img_size, img_size)),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    ])



    file_path = img_path

    with torch.no_grad():
        img_tensor = data_transform(Image.open(file_path).convert('RGB')).unsqueeze(0)
        img_tensor = img_tensor.cuda()  # Variable/volatile is deprecated; torch.no_grad() already disables autograd
        output = F.softmax(model(img_tensor), dim=1).cpu().numpy()
    # defect_prob = round(output.data[0, 1], 6)
    pred = np.argmax(output)
    pred = CLS_NAME[pred]

    score = np.max(output)
    if pred == FALSE_NAME:
        score = 0
    if score <= 0.85 and pred == NG_NAME:  # an NG prediction below the confidence threshold is treated as FALSE
        pred = FALSE_NAME
        score = 0

    return pred, score


if __name__ == '__main__':
    model_path=r'E:\code_tj\CBAM_PyTorch\datasets\work_dir\checkpoint\resnet50_cbam\Models_epoch_0.ckpt'
    img=r'E:\code_tj\CBAM_PyTorch\datasets\val\v06\W0C2P0206A0108_WHITE_20210125.jpg'
    model=init_cls_model(model_path, is_multi_gpu=False, classes=2)
    pre=cls_judge(img, model, img_size=224)
    print(pre)

Reference: https://blog.csdn.net/qq_14845119/article/details/81393127

 
