deformable attention中生成初始采样点位置(init_weights或者_reset_parameters函数)

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
def _reset_parameters(self):
  constant_(self.sampling_offsets.weight.data, 0)
  """初始化偏移量预测的偏置(bias), 使得初始偏移位置犹如不同大小的方形卷积核组合"""
 
  # (8,) [0, pi / 4, pi / 2, 3 * pi / 2, ..., 7 * pi / 4]
  thetas = torch.arange(self.n_heads, dtype = torch.float32).(2.0 * math.pi / self.n_heads)
  # (8, 2)
  grid_init = torch.stack([thetas.cos(), thetas.sin()], -1)
  # grid_init / grid_init.abs().max(-1, keepdi=True)[0]这步计算得到8个头对应的坐标偏移:
  # (1, 0), (1, 1), (0, 1), (-1, 1), (-1, 0), (-1, -1), (0, -1), (1, -1)
  # 从图形视觉上来看, 形成的偏移位置相当于是3x3, 5x5, 7x7, 9x9正方形卷积核(出去中心,中心是参考点本身)
  for i in range(self.n_points):
    grid_init[:, :, i, :] *= i + 1
  # 注意这里取消了梯度,只是借助nn.Parameter把数值设置进去
  with torch.no_grad():
    self.sampling_offsets.bias = nn.Parameters(grid_init.view(-1))

最终效果就是,初始的采样点位置相当于会分布在参考点 3x3、5x5、7x7、9x9 方形邻域。

来源:Deformable DETR论文精读+代码详解 - yejian's blog

 

posted @   Picassooo  阅读(25)  评论(0编辑  收藏  举报
相关博文:
阅读排行:
· Manus爆火,是硬核还是营销?
· 终于写完轮子一部分:tcp代理 了,记录一下
· 震惊!C++程序真的从main开始吗?99%的程序员都答错了
· 别再用vector<bool>了!Google高级工程师:这可能是STL最大的设计失误
· 单元测试从入门到精通
历史上的今天:
2024-01-23 转:VS code 选择指定环境下的python运行代码
2021-01-23 转:三维动作捕捉技术介绍
点击右上角即可分享
微信分享提示