【YOLOv8改进】STA(Super Token Attention) 超级令牌注意力机制 (论文笔记+引入代码)

摘要

视觉Transformer在许多视觉任务上展示了卓越的性能。然而,它在浅层捕获局部特征时可能会面临高度冗余的问题。因此,使用了局部自注意力或早期阶段的卷积来减少这种冗余,但这牺牲了捕获长距离依赖的能力。一个挑战随之而来:在神经网络的早期阶段,我们是否能高效且有效地进行全局上下文建模?为解决这一问题,我们从超像素的设计中获得启示,这种设计通过减少图像基元的数量来简化后续处理,并在视觉Transformer中引入了超级令牌。超级令牌旨在为视觉内容提供有意义的语义分割,这样既减少了自注意力中的令牌数量,也保留了全局建模能力。具体而言,我们提出了一种简单而有效的超级令牌注意力(STA)机制,它包括三个步骤:首先通过稀疏关联学习从视觉令牌中抽取超级令牌,接着对这些超级令牌进行自注意力处理,最后将它们映射回原始的令牌空间。STA通过将普通的全局注意力分解为稀疏关联图与低维度注意力的乘积,极大地提高了捕获全局依赖的效率。基于STA,我们开发了一个层次化的视觉Transformer。广泛的实验显示了它在各种视觉任务上的强大性能。特别是,在没有任何额外训练数据或标签的情况下,它在ImageNet-1K上实现了86.4%的顶级准确率,以及在COCO检测任务上达到53.9的盒AP和46.8的掩码AP,在ADE20K语义分割任务上实现了51.9的mIOU。

YOLO目标检测创新改进与实战案例专栏

专栏目录: YOLO有效改进系列及项目实战目录 包含卷积,主干 注意力,检测头等创新机制 以及 各种目标检测分割项目实战案例

专栏链接: YOLO基础解析+创新改进+实战案例

创新点

  1. 引入超级标记(super tokens):通过引入超级标记的概念,实现了在视觉Transformer中的全局上下文建模。超级标记将原始标记聚合成具有语义意义的单元,从而减少了自注意力计算的复杂度,提高了全局信息的捕获效率。

  2. Super Token Attention(STA)机制:提出了一种简单而强大的超级标记注意力机制,包括超级标记采样、多头自注意力和标记上采样等步骤。STA通过稀疏映射和自注意力计算,在全局和局部之间实现了高效的信息交互,有效地学习全局表示。

  3. Hierarchical Vision Transformer:设计了一种层次化的视觉Transformer结构,结合了卷积层和超级标记Transformer块,以在不同层次上实现高效和有效的表示学习。这种结构在各种视觉任务上展现了出色的性能,包括图像分类、目标检测和语义分割等。

  4. 性能优越性:在多个视觉任务上进行了广泛的实验验证,包括ImageNet图像分类、COCO目标检测和ADE20K语义分割等。实验结果表明,基于超级标记的视觉Transformer在各项任务上均取得了优异的性能,超越了其他Transformer模型的表现。

yolov8 引入

class StokenAttention(nn.Module):
   def __init__(self, dim, stoken_size=[8,8], n_iter=1, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0.):
       super().__init__()

       self.n_iter = n_iter  # 迭代次数
       self.stoken_size = stoken_size  # 空间令牌的大小

       self.scale = dim ** - 0.5  # 缩放因子

       self.unfold = Unfold(3)  # 定义Unfold实例
       self.fold = Fold(3)  # 定义Fold实例

       # 定义空间注意力机制
       self.stoken_refine = StAttention(dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=proj_drop)

   def stoken_forward(self, x):
       '''
          x: (B, C, H, W)
       '''
       B, C, H0, W0 = x.shape
       h, w = self.stoken_size

       pad_l = pad_t = 0
       pad_r = (w - W0 % w) % w
       pad_b = (h - H0 % h) % h
       if pad_r > 0 or pad_b > 0:
           x = F.pad(x, (pad_l, pad_r, pad_t, pad_b))  # 对输入张量进行填充

       _, _, H, W = x.shape

       hh, ww = H // h, W // w

       stoken_features = F.adaptive_avg_pool2d(x, (hh, ww))  # 自适应平均池化
       pixel_features = x.reshape(B, C, hh, h, ww, w).permute(0, 2, 4, 3, 5, 1).reshape(B, hh * ww, h * w, C)

       with torch.no_grad():
           for idx in range(self.n_iter):
               stoken_features = self.unfold(stoken_features)  # 展开空间令牌特征
               stoken_features = stoken_features.transpose(1, 2).reshape(B, hh * ww, C, 9)
               affinity_matrix = pixel_features @ stoken_features * self.scale  # 计算亲和矩阵
               affinity_matrix = affinity_matrix.softmax(-1)  # 对亲和矩阵进行softmax

               affinity_matrix_sum = affinity_matrix.sum(2).transpose(1, 2).reshape(B, 9, hh, ww)
               affinity_matrix_sum = self.fold(affinity_matrix_sum)
               if idx < self.n_iter - 1:
                   stoken_features = pixel_features.transpose(-1, -2) @ affinity_matrix
                   stoken_features = self.fold(stoken_features.permute(0, 2, 3, 1).reshape(B * C, 9, hh, ww)).reshape(B, C, hh, ww)
                   stoken_features = stoken_features / (affinity_matrix_sum + 1e-12)  # 归一化

       stoken_features = pixel_features.transpose(-1, -2) @ affinity_matrix
       stoken_features = self.fold(stoken_features.permute(0, 2, 3, 1).reshape(B * C, 9, hh, ww)).reshape(B, C, hh, ww)
       stoken_features = stoken_features / (affinity_matrix_sum.detach() + 1e-12)  # 归一化

       stoken_features = self.stoken_refine(stoken_features)  # 细化空间令牌特征

       stoken_features = self.unfold(stoken_features)  # 展开细化后的特征
       stoken_features = stoken_features.transpose(1, 2).reshape(B, hh * ww, C, 9)
       pixel_features = stoken_features @ affinity_matrix.transpose(-1, -2)  # 计算最终的像素特征

       pixel_features = pixel_features.reshape(B, hh, ww, C, h, w).permute(0, 3, 1, 4, 2, 5).reshape(B, C, H, W)
       if pad_r > 0 or pad_b > 0:
           pixel_features = pixel_features[:, :, :H0, :W0]  # 去除填充部分

       return pixel_features  # 返回最终的像素特征

   def direct_forward(self, x):
       B, C, H, W = x.shape
       stoken_features = x
       stoken_features = self.stoken_refine(stoken_features)
       return stoken_features  # 返回直接计算的空间令牌特征

   def forward(self, x):
       if self.stoken_size[0] > 1 or self.stoken_size[1] > 1:
           return self.stoken_forward(x)  # 使用空间令牌前向计算
       else:
           return self.direct_forward(x)  # 直接前向计算

task与yaml配置

详见:https://blog.csdn.net/shangyanaf/article/details/139113660

posted @ 2024-06-20 21:29  YOLOv8大师  阅读(13)  评论(0编辑  收藏  举报