AI艺术鉴赏代码解析

亚军代码

作者采用了resnest200作为主干网络
采用多次训练模型预测投票的方法
来进行最终分类

数据增强方面

除了常规的方法之外,还有几个操作目前不知道意思
需要进一步研究

以下是部分的transform的写法

normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])

    train_data = TrainDataset(train_data_list,
                              transform=transforms.Compose([
                                  transforms.Resize((img_size, img_size)),
                                  transforms.ColorJitter(0.3, 0.3, 0.3, 0.15),
                                  # transforms.RandomRotation(30),
                                  transforms.RandomHorizontalFlip(),
#                                   transforms.RandomVerticalFlip(),
#                                   transforms.RandomGrayscale(),
                                  FixedRotation([-16,-14,-12,-10,-8,-6,-4,-2,0,2,4,6,8,10,12,14,16]),
                                  transforms.ToTensor(),
                                  normalize,
                              ]))

    val_data = ValDataset(val_data_list,
                          transform=transforms.Compose([
                              transforms.Resize((img_size, img_size)),
                              # transforms.CenterCrop((500, 500)),
                              transforms.ToTensor(),
                              normalize,
                          ]))

    test_data = TestDataset(test_data_list,
                            transform=transforms.Compose([
                                transforms.Resize((img_size, img_size)),
                                # transforms.CenterCrop((500, 500)),
                                transforms.ToTensor(),
                                normalize,
                                # transforms.Lambda(lambda crops: torch.stack([transforms.ToTensor()(crop) for crop in crops])),
                                # transforms.Lambda(lambda crops: torch.stack([normalize(crop) for crop in crops])),
                            ]))

    train_loader = DataLoader(train_data, batch_size=batch_size, shuffle=True, pin_memory=True, num_workers=workers,drop_last=True)
    val_loader = DataLoader(val_data, batch_size=batch_size * 2, shuffle=False, pin_memory=False, num_workers=workers,drop_last=True)
    test_loader = DataLoader(test_data, batch_size=batch_size * 2, shuffle=False, pin_memory=False, num_workers=workers)

感觉亚军代码除了投票机制以及数据增强机制之外
并没有什么出色的地方,不知道为啥正确率那么高

季军代码

预处理

首先记录每类的样本数
设置阈值划分验证集

数据增强

对图片进行resize到224/0.875大小
然后中心裁剪,保证包含大部分特征

模型方面

  • 使用了SE模块和GeneralizedMeanPooling
class SELayer(nn.Module):
    def __init__(self, channel, reduction=16):
        super(SELayer, self).__init__()
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.fc = nn.Sequential(
            nn.Linear(channel, channel // reduction, bias=False),
            nn.ReLU(inplace=True),
            nn.Linear(channel // reduction, channel, bias=False),
            nn.Sigmoid()
        )

    def forward(self, x):
        b, c, _, _ = x.size()
        y = self.avg_pool(x).view(b, c)
        y = self.fc(y).view(b, c, 1, 1)
        return y


class AdaptiveConcatPool2d(nn.Module):
    def __init__(self, sz=(1,1)):
        super().__init__()
        self.ap = nn.AdaptiveAvgPool2d(sz)
        self.mp = nn.AdaptiveMaxPool2d(sz)
        
    def forward(self, x):
        return torch.cat([self.mp(x), self.ap(x)], 1)


class GeneralizedMeanPooling(nn.Module):
    def __init__(self, norm=3, output_size=1, eps=1e-6):
        super().__init__()
        assert norm > 0
        self.p = float(norm)
        self.output_size = output_size
        self.eps = eps

    def forward(self, x):
        x = x.clamp(min=self.eps).pow(self.p)
        
        return torch.nn.functional.adaptive_avg_pool2d(x, self.output_size).pow(1. / self.p)

    def __repr__(self):
        return self.__class__.__name__ + '(' \
               + str(self.p) + ', ' \
               + 'output_size=' + str(self.output_size) + ')'

主干网络采用efficientnet和resnet50
搭配以上两个模块进行多次(4次)预测结果进行投票

感觉大佬们的训练方法,数据增强技术
还有各种模块的使用都很强
希望自己有一天也能这么强

posted @ 2020-11-05 20:04  CrosseaLL  阅读(246)  评论(0编辑  收藏  举报