CvT: 如何将卷积的优势融入Transformer

【GiantPandaCV导语】与之前BoTNet不同,CvT虽然题目中有卷积的字样,但是实际总体来说依然是以Transformer Block为主的,在Token的处理方面引入了卷积,从而为模型带来的局部性。最终CvT最高拿下了87.7%的Top1准确率。

引言

CvT架构的Motivation也是将局部性引入Vision Transformer架构中,期望通过引入局部性得到更高的性能和效率权衡。因此我们主要关注CvT是如何引入局部性的。具体来说提出了两点改进:

  • Convolutional token embedding
  • Convolutional Projection

通过以上改进,模型不仅具有卷积的优势(局部感受野、权重共享、空间下采样等特性带来的优势),如平移不变形、尺度不变性、旋转不变性等,也保持了Self Attention的优势,如动态注意力、全局语义信息、更强的泛化能力等。

展开一点讲,Convolutional Vision Transformer有两点核心:

  • 第一步,参考CNN的架构,将Transformer也设计为多阶段的层次架构,每个stage之前使用convolutional token embedding,通过使用卷积+layer normalization能够实现降维的功能(注:逐渐降低序列长度的同时,增加每个token的维度,可以类比卷积中feature map砍半,通道数增加的操作)
  • 第二步,使用Convolutional Projection取代原来的Linear Projection,该模块实际使用的是深度可分离卷积实现,这样也能有效捕获局部语义信息。

需要注意的是:CvT去掉了Positional Embedding模块,发现对模型性能没有任何影响。认为可以简化架构的设计,并且可以在分辨率变化的情况下更容易适配。

比较

在相关工作部分,CvT总结了一个表格,比较方便对比:

方法

在引言部分已经讲得比较详细了,下面对照架构图复盘一下(用尽可能通俗的语言描述):

  • 绿色框是conv token embedding操作,通俗来讲,使用了超大卷积核来提升局部性不足的问题。
  • 右图蓝色框中展示的是改进的self attention,通俗来讲,使用了non local的操作,使用深度可分离卷积取代MLP做Projection,如下图所示:

  • 如图(a)所示,Vision Transformer中使用的是MLP进行Linear Projection, 这样的信息是全局性的,但是计算量比较大。
  • 如图(b)所示,使用卷积进行映射,这种操作类似Non Local Network,使用卷积进行映射。
  • 如图(c)所示,使用的是带stride的卷积进行压缩,这样做是处于对效率的考量,token数量可以降低四倍,会带来一定的性能损失。

Positional embedding探讨:

由于Convolutional Projection在每个Transformer Block中都是用,配合Convolutional Token Embedding操作,能够给模型足够的能力来建模局部空间关系,因此可以去掉Transformer中的Positional Embedding操作。从下表发现,pe对模型性能影响不大。

与其他工作的对比:

  • 同期工作1:Tokens-to-Tokens ViT: 使用Progressive Tokenization整合临近token,使用Transformer-based骨干网络具有局部性的同时,还能降低token序列长度。
  • 区别:CvT使用的是multi-stage的过程,token长度降低的同时,其维度在增加,从而保证模型的容量。同时计算量相比T2T有所改善。
  • 同期工作2:Pyramid Vision Transformer(PVT): 引入了金字塔架构,使得PVT可以作为Backbone应用于Dense prediction任务中。
  • 区别:CvT也使用了金字塔架构,区别在于CvT中提出使用stride卷积来实现空间降采样,进一步融合了局部信息。

最终模型架构如下:

实验

左图中令人感兴趣的是BiT,这篇是谷歌的文章big transfer,探究CNN架构在大规模数据与训练的效果,可以看出即便是纯CNN架构模型参数量也可以非常巨大,而Vision Transformer还有CvT等在同等精度下模型参数量远小于BiT,这一定程度上说明了Transformer结合CNN在数据量足够的情况下性能可以非常可观,要比单纯CNN架构的模型性能更优。

右图展示了CvT和几种vision transformer架构的性能比较,可见CvT在权衡方面做的非常不错。

与SOTA比较:

有趣的是CvT-13-NAS也采用了搜索的方法DA-NAS,主要搜索对象是key和value的stride,以及MLP的Expansion Ratio, 最终搜索的结果要比Baseline略好。

在无需JFT数据集的情况下,CvT最高调整可以达到87.7%的top1 准确率。

其他数据集结果:

消融实验

代码

Convolutional Token Embedding代码实现:可以看出,实际上就是大卷积核+大Stride的滑动引入的局部性。

class ConvEmbed(nn.Module):
    """ Image to Conv Embedding
    """
    def __init__(self,
                 patch_size=7,
                 in_chans=3,
                 embed_dim=64,
                 stride=4,
                 padding=2,
                 norm_layer=None):
        super().__init__()
        patch_size = to_2tuple(patch_size)
        self.patch_size = patch_size

        self.proj = nn.Conv2d(
            in_chans, embed_dim,
            kernel_size=patch_size,
            stride=stride,
            padding=padding
        )
        self.norm = norm_layer(embed_dim) if norm_layer else None

    def forward(self, x):
        x = self.proj(x)

        B, C, H, W = x.shape
        x = rearrange(x, 'b c h w -> b (h w) c')
        if self.norm:
            x = self.norm(x)
        x = rearrange(x, 'b (h w) c -> b c h w', h=H, w=W)

        return x

Convolutional Projection代码实现,具体看_build_projection函数:

class Attention(nn.Module):
    def __init__(self,
                 dim_in,
                 dim_out,
                 num_heads,
                 qkv_bias=False,
                 attn_drop=0.,
                 proj_drop=0.,
                 method='dw_bn',
                 kernel_size=3,
                 stride_kv=1,
                 stride_q=1,
                 padding_kv=1,
                 padding_q=1,
                 with_cls_token=True,
                 **kwargs
                 ):
        super().__init__()
        self.stride_kv = stride_kv
        self.stride_q = stride_q
        self.dim = dim_out
        self.num_heads = num_heads
        # head_dim = self.qkv_dim // num_heads
        self.scale = dim_out ** -0.5
        self.with_cls_token = with_cls_token

        self.conv_proj_q = self._build_projection(
            dim_in, dim_out, kernel_size, padding_q,
            stride_q, 'linear' if method == 'avg' else method
        )
        self.conv_proj_k = self._build_projection(
            dim_in, dim_out, kernel_size, padding_kv,
            stride_kv, method
        )
        self.conv_proj_v = self._build_projection(
            dim_in, dim_out, kernel_size, padding_kv,
            stride_kv, method
        )

        self.proj_q = nn.Linear(dim_in, dim_out, bias=qkv_bias)
        self.proj_k = nn.Linear(dim_in, dim_out, bias=qkv_bias)
        self.proj_v = nn.Linear(dim_in, dim_out, bias=qkv_bias)

        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim_out, dim_out)
        self.proj_drop = nn.Dropout(proj_drop)

    def _build_projection(self,
                          dim_in,
                          dim_out,
                          kernel_size,
                          padding,
                          stride,
                          method):
        if method == 'dw_bn':
            proj = nn.Sequential(OrderedDict([
                ('conv', nn.Conv2d(
                    dim_in,
                    dim_in,
                    kernel_size=kernel_size,
                    padding=padding,
                    stride=stride,
                    bias=False,
                    groups=dim_in
                )),
                ('bn', nn.BatchNorm2d(dim_in)),
                ('rearrage', Rearrange('b c h w -> b (h w) c')),
            ]))
        elif method == 'avg':
            proj = nn.Sequential(OrderedDict([
                ('avg', nn.AvgPool2d(
                    kernel_size=kernel_size,
                    padding=padding,
                    stride=stride,
                    ceil_mode=True
                )),
                ('rearrage', Rearrange('b c h w -> b (h w) c')),
            ]))
        elif method == 'linear':
            proj = None
        else:
            raise ValueError('Unknown method ({})'.format(method))

        return proj

    def forward_conv(self, x, h, w):
        if self.with_cls_token:
            cls_token, x = torch.split(x, [1, h*w], 1)

        x = rearrange(x, 'b (h w) c -> b c h w', h=h, w=w)

        if self.conv_proj_q is not None:
            q = self.conv_proj_q(x)
        else:
            q = rearrange(x, 'b c h w -> b (h w) c')

        if self.conv_proj_k is not None:
            k = self.conv_proj_k(x)
        else:
            k = rearrange(x, 'b c h w -> b (h w) c')

        if self.conv_proj_v is not None:
            v = self.conv_proj_v(x)
        else:
            v = rearrange(x, 'b c h w -> b (h w) c')

        if self.with_cls_token:
            q = torch.cat((cls_token, q), dim=1)
            k = torch.cat((cls_token, k), dim=1)
            v = torch.cat((cls_token, v), dim=1)

        return q, k, v

    def forward(self, x, h, w):
        if (
            self.conv_proj_q is not None
            or self.conv_proj_k is not None
            or self.conv_proj_v is not None
        ):
            q, k, v = self.forward_conv(x, h, w)

        q = rearrange(self.proj_q(q), 'b t (h d) -> b h t d', h=self.num_heads)
        k = rearrange(self.proj_k(k), 'b t (h d) -> b h t d', h=self.num_heads)
        v = rearrange(self.proj_v(v), 'b t (h d) -> b h t d', h=self.num_heads)

        attn_score = torch.einsum('bhlk,bhtk->bhlt', [q, k]) * self.scale
        attn = F.softmax(attn_score, dim=-1)
        attn = self.attn_drop(attn)

        x = torch.einsum('bhlt,bhtv->bhlv', [attn, v])
        x = rearrange(x, 'b h t d -> b t (h d)')

        x = self.proj(x)
        x = self.proj_drop(x)

        return x

参考

https://github.com/microsoft/CvT/blob/main/lib/models/cls_cvt.py

https://arxiv.org/pdf/2103.15808.pdf

https://zhuanlan.zhihu.com/p/142864566

笔者在cifar10数据集上修改了CvT中的Stride等参数,在不用任何数据增强和Trick的情况下得到了下图结果,Top1为84.74。虽然看上去性能比较差,但是这还没有调参以及加上数据增强方法,只训练了200个epoch的结果。

python train.py --model 'cvt' --name "cvt" --sched 'cosine' --epochs 200 --lr 0.01

感兴趣的可以点击下面链接调参:

https://github.com/pprp/pytorch-cifar-model-zoo

posted @ 2022-01-01 16:11  pprp  阅读(656)  评论(0编辑  收藏  举报