【YOLOv8改进】DAT(Deformable Attention):可变性注意力 (论文笔记+引入代码)

YOLO目标检测创新改进与实战案例专栏

专栏目录: YOLO有效改进系列及项目实战目录 包含卷积,主干 注意力,检测头等创新机制 以及 各种目标检测分割项目实战案例

专栏链接: YOLO基础解析+创新改进+实战案例

摘要

Transformers最近在各种视觉任务中展现出了优越的性能。较大甚至是全局的感受野赋予了Transformer模型比其卷积神经网络(CNN)对手更强的表征能力。然而,简单地扩大感受野也带来了几个问题。一方面,使用密集注意力(例如在ViT中)会导致过高的内存和计算成本,并且特征可能会受到兴趣区域之外的无关部分的影响。另一方面,PVT或Swin Transformer采用的稀疏注意力对数据不敏感,可能限制了建模长距离关系的能力。为了解决这些问题,我们提出了一种新型的可变形自注意力模块,其中在自注意力中键和值对的位置是以数据为基础选择的。这种灵活的方案使自注意力模块能够聚焦于相关区域并捕捉更多信息特征。在此基础上,我们提出了Deformable Attention Transformer,这是一种用于图像分类和密集预测任务的通用主干模型,具有可变形注意力。广泛的实验表明,我们的模型在综合基准测试中实现了持续改进的结果。代码可在https://github.com/LeapLabTHU/DAT获取。

基本原理

关键

  1. 数据依赖的位置选择:Deformable Attention允许在自注意力机制中以数据依赖的方式选择键和值对的位置,使模型能够根据输入数据动态调整注意力的焦点。
  2. 灵活的偏移学习:通过学习偏移量,Deformable Attention可以将关键点和值移动到重要区域,从而提高模型对关键特征的捕获能力。
  3. 全局键共享:Deformable Attention学习一组全局键,这些键在不同的视觉标记之间共享,有助于模型捕获长距离的相关性。
  4. 空间自适应机制:Deformable Attention可以根据输入数据的特征动态调整注意力模式,从而适应不同的视觉任务和场景。

通过相对于Swin-Transformer和PVT的改进,加入了可变形机制,同时控制网络不增加太多的计算量。作者认为,缩小q对应的k的范围,能够减少无关信息的干扰,增强信息的捕捉,于是引入了DCN机制到注意力模块中,提出了一种新的注意力模块:可变形多头注意力模块。该模块通过对k和v进行DCN偏移后再计算注意力,从而提升了性能。

在可变形多头注意力模块中,输入特征图像 $x \in \mathbb{R}^{H \times W \times C}$ 生成一个参考网格,其中参考点 $p \in \mathbb{R}^{H_G \times W_G \times 2}$。该网格是从输入特征图 $x$ 降采样而来,降采样系数为 $r$, $H_G = H / r, W_G = W / r$。参考点的值代表的是坐标值 $(0, 0), \ldots, (H_G - 1, W_G - 1)$,再归一化到 $[-1, +1]$。

输入特征图像 $x$ 通过线性投影得到 $q = x W_q$,再输入到一个轻量级子网络offset network,生成偏移量 $\Delta p = \theta_{\text{offset}}(q)$。为了稳定训练过程,使用了一些预定义的因子来衡量 $\Delta p$ 的振幅,以防止太大的offset,即 $\Delta p \leftarrow \text{sinh}(\Delta p)$。

然后将获得的offset作用在参考点上,获得变形点的位置,进行特征采样(双线性插值)得到 $\hat{x}$,再通过投影矩阵生成Key和Value, $\hat{k} = \hat{x} W_k, \hat{v} = \hat{x} W_v$。

$qkv$进行多头注意力计算,同时加入相对位置偏移嵌入。最后将获得的多头特征拼接起来,通过投影矩阵获得最终的注意力模块输出 $Z$。

yolov8 引入

class DAttentionBaseline(nn.Module):

   def __init__(
       self, q_size, kv_size, n_heads, n_head_channels, n_groups,
       attn_drop, proj_drop, stride, 
       offset_range_factor, use_pe, dwc_pe,
       no_off, fixed_pe, ksize, log_cpb
   ):
       # 初始化函数,定义了所需的参数
       super().__init__()
       self.dwc_pe = dwc_pe  # 是否使用深度卷积位置编码
       self.n_head_channels = n_head_channels  # 每个头的通道数
       self.scale = self.n_head_channels ** -0.5  # 缩放因子,等于每个头的通道数的负0.5次方
       self.n_heads = n_heads  # 多头注意力机制中的头数
       self.q_h, self.q_w = q_size  # query的高和宽
       self.kv_h, self.kv_w = self.q_h // stride, self.q_w // stride  # 计算键值对的高和宽
       self.nc = n_head_channels * n_heads  # 总的通道数
       self.n_groups = n_groups  # 分组数
       self.n_group_channels = self.nc // self.n_groups  # 每组的通道数
       self.n_group_heads = self.n_heads // self.n_groups  # 每组的头数
       self.use_pe = use_pe  # 是否使用位置编码
       self.fixed_pe = fixed_pe  # 是否使用固定的位置编码
       self.no_off = no_off  # 是否禁用偏移
       self.offset_range_factor = offset_range_factor  # 偏移范围因子
       self.ksize = ksize  # 卷积核尺寸
       self.log_cpb = log_cpb  # 是否使用对数相对位置偏置
       self.stride = stride  # 步幅
       kk = self.ksize
       pad_size = kk // 2 if kk != stride else 0  # 计算填充大小

       # 定义卷积偏移网络
       self.conv_offset = nn.Sequential(
           nn.Conv2d(self.n_group_channels, self.n_group_channels, kk, stride, pad_size, groups=self.n_group_channels),
           LayerNormProxy(self.n_group_channels),  # 使用LayerNormProxy进行归一化
           nn.GELU(),  # 使用GELU激活函数
           nn.Conv2d(self.n_group_channels, 2, 1, 1, 0, bias=False)  # 输出偏移量
       )
       if self.no_off:
           for m in self.conv_offset.parameters():
               m.requires_grad_(False)  # 如果不使用偏移,禁用偏移网络的参数更新

       # 定义投影层
       self.proj_q = nn.Conv2d(
           self.nc, self.nc,
           kernel_size=1, stride=1, padding=0  # query投影
       )

       self.proj_k = nn.Conv2d(
           self.nc, self.nc,
           kernel_size=1, stride=1, padding=0  # key投影
       )

       self.proj_v = nn.Conv2d(
           self.nc, self.nc,
           kernel_size=1, stride=1, padding=0  # value投影
       )

       self.proj_out = nn.Conv2d(
           self.nc, self.nc,
           kernel_size=1, stride=1, padding=0  # 输出投影
       )

       self.proj_drop = nn.Dropout(proj_drop, inplace=True)  # 投影层的Dropout
       self.attn_drop = nn.Dropout(attn_drop, inplace=True)  # 注意力层的Dropout

       # 相对位置嵌入的定义
       if self.use_pe and not self.no_off:
           if self.dwc_pe:
               self.rpe_table = nn.Conv2d(
                   self.nc, self.nc, kernel_size=3, stride=1, padding=1, groups=self.nc)  # 深度卷积位置编码
           elif self.fixed_pe:
               self.rpe_table = nn.Parameter(
                   torch.zeros(self.n_heads, self.q_h * self.q_w, self.kv_h * self.kv_w)
               )
               trunc_normal_(self.rpe_table, std=0.01)  # 截断正态分布初始化
           elif self.log_cpb:
               # 借用自Swin-V2
               self.rpe_table = nn.Sequential(
                   nn.Linear(2, 32, bias=True),
                   nn.ReLU(inplace=True),
                   nn.Linear(32, self.n_group_heads, bias=False)
               )
           else:
               self.rpe_table = nn.Parameter(
                   torch.zeros(self.n_heads, self.q_h * 2 - 1, self.q_w * 2 - 1)
               )
               trunc_normal_(self.rpe_table, std=0.01)  # 截断正态分布初始化
       else:
           self.rpe_table = None

   @torch.no_grad()
   def _get_ref_points(self, H_key, W_key, B, dtype, device):
       # 获取参考点
       ref_y, ref_x = torch.meshgrid(
           torch.linspace(0.5, H_key - 0.5, H_key, dtype=dtype, device=device),
           torch.linspace(0.5, W_key - 0.5, W_key, dtype=dtype, device=device),
           indexing='ij'  # 保持矩阵索引一致
       )
       ref = torch.stack((ref_y, ref_x), -1)  # 堆叠y和x坐标
       ref[..., 1].div_(W_key - 1.0).mul_(2.0).sub_(1.0)  # 归一化到[-1, 1]范围
       ref[..., 0].div_(H_key - 1.0).mul_(2.0).sub_(1.0)  # 归一化到[-1, 1]范围
       ref = ref[None, ...].expand(B * self.n_groups, -1, -1, -1)  # 扩展维度,适应批量和分组

       return ref
   
   @torch.no_grad()
   def _get_q_grid(self, H, W, B, dtype, device):
       # 获取query网格
       ref_y, ref_x = torch.meshgrid(
           torch.arange(0, H, dtype=dtype, device=device),
           torch.arange(0, W, dtype=dtype, device=device),
           indexing='ij'  # 保持矩阵索引一致
       )
       ref = torch.stack((ref_y, ref_x), -1)  # 堆叠y和x坐标
       ref[..., 1].div_(W - 1.0).mul_(2.0).sub_(1.0)  # 归一化到[-1, 1]范围
       ref[..., 0].div_(H - 1.0).mul_(2.0).sub_(1.0)  # 归一化到[-1, 1]范围
       ref = ref[None, ...].expand(B * self.n_groups, -1, -1, -1)  # 扩展维度,适应批量和分组

       return ref

   def forward(self, x):
       # 前向传播函数
       B, C, H, W = x.size()  # 获取输入的尺寸
       dtype, device = x.dtype, x.device

       q = self.proj_q(x)  # 对输入x进行query投影
       q_off = einops.rearrange(q, 'b (g c) h w -> (b g) c h w', g=self.n_groups, c=self.n_group_channels)  # 重排列tensor的维度
       offset = self.conv_offset(q_off).contiguous()  # 计算偏移量
       Hk, Wk = offset.size(2), offset.size(3)  # 获取偏移量的高和宽
       n_sample = Hk * Wk  # 计算采样点数量

       if self.offset_range_factor >= 0 and not self.no_off:
           offset_range = torch.tensor([1.0 / (Hk - 1.0), 1.0 / (Wk - 1.0)], device=device).reshape(1, 2, 1, 1)
           offset = offset.tanh().mul(offset_range).mul(self.offset_range_factor)

       offset = einops.rearrange(offset, 'b p h w -> b h w p')
       reference = self._get_ref_points(Hk, Wk, B, dtype, device)

       if self.no_off:
           offset = offset.fill_(0.0)

       if self.offset_range_factor >= 0:
           pos = offset + reference
       else:
           pos = (offset + reference).clamp(-1., +1.)

       if self.no_off:
           x_sampled = F.avg_pool2d(x, kernel_size=self.stride, stride=self.stride)
           assert x_sampled.size(2) == Hk and x_sampled.size(3) == Wk, f"Size is {x_sampled.size()}"
       else:
           x_sampled = F.grid_sample(
               input=x.reshape(B * self.n_groups, self.n_group_channels, H, W), 
               grid=pos[..., (1, 0)],  # y, x -> x, y
               mode='bilinear', align_corners=True)  # 进行双线性插值采样

       x_sampled = x_sampled.reshape(B, C, 1, n_sample)

       q = q.reshape(B * self.n_heads, self.n_head_channels, H * W)
       k = self.proj_k(x_sampled).reshape(B * self.n_heads, self.n_head_channels, n_sample)
       v = self.proj_v(x_sampled).reshape(B * self.n_heads, self.n_head_channels, n_sample)

       attn = torch.einsum('b c m, b c n -> b m n', q, k)  # 计算注意力权重
       attn = attn.mul(self.scale)

       if self.use_pe and (not self.no_off):
           if self.dwc_pe:
               residual_lepe = self.rpe_table(q.reshape(B, C, H, W)).reshape(B * self.n_heads, self.n_head_channels, H * W)
           elif self.fixed_pe:
               rpe_table = self.rpe_table
               attn_bias = rpe_table[None, ...].expand(B, -1, -1, -1)
               attn = attn + attn_bias.reshape(B * self.n_heads, H * W, n_sample)
           elif self.log_cpb:
               q_grid = self._get_q_grid(H, W, B, dtype, device)
               displacement = (q_grid.reshape(B * self.n_groups, H * W, 2).unsqueeze(2) - pos.reshape(B * self.n_groups, n_sample, 2).unsqueeze(1)).mul(4.0)  # 计算位移
               displacement = torch.sign(displacement) * torch.log2(torch.abs(displacement) + 1.0) / np.log2(8.0)
               attn_bias = self.rpe_table(displacement)  # 计算相对位置嵌入偏置
               attn = attn + einops.rearrange(attn_bias, 'b m n h -> (b h) m n', h=self.n_group_heads)
           else:
               rpe_table = self.rpe_table
               rpe_bias = rpe_table[None, ...].expand(B, -1, -1, -1)
               q_grid = self._get_q_grid(H, W, B, dtype, device)
               displacement = (q_grid.reshape(B * self.n_groups, H * W, 2).unsqueeze(2) - pos.reshape(B * self.n_groups, n_sample, 2).unsqueeze(1)).mul(0.5)
               attn_bias = F.grid_sample(
                   input=einops.rearrange(rpe_bias, 'b (g c) h w -> (b g) c h w', c=self.n_group_heads, g=self.n_groups),
                   grid=displacement[..., (1, 0)],
                   mode='bilinear', align_corners=True)  # 双线性插值计算相对位置偏置

               attn_bias = attn_bias.reshape(B * self.n_heads, H * W, n_sample)
               attn = attn + attn_bias

       attn = F.softmax(attn, dim=2)  # 对注意力权重进行softmax
       attn = self.attn_drop(attn)

       out = torch.einsum('b m n, b c n -> b c m', attn, v)  # 计算注意力输出

       if self.use_pe and self.dwc_pe:
           out = out + residual_lepe
       out = out.reshape(B, C, H, W)

       y = self.proj_drop(self.proj_out(out))  # 投影输出并进行Dropout

       return y, pos.reshape(B, self.n_groups, Hk, Wk, 2), reference.reshape(B, self.n_groups, Hk, Wk, 2)

task与yaml配置

详见:https://blog.csdn.net/shangyanaf/article/details/139193465

posted @ 2024-06-06 20:42  YOLOv8大师  阅读(186)  评论(0编辑  收藏  举报