deformable attention中生成初始采样点位置(init_weights或者_reset_parameters函数)
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 | def _reset_parameters( self ): constant_( self .sampling_offsets.weight.data, 0 ) """初始化偏移量预测的偏置(bias), 使得初始偏移位置犹如不同大小的方形卷积核组合""" # (8,) [0, pi / 4, pi / 2, 3 * pi / 2, ..., 7 * pi / 4] thetas = torch.arange( self .n_heads, dtype = torch.float32).( 2.0 * math.pi / self .n_heads) # (8, 2) grid_init = torch.stack([thetas.cos(), thetas.sin()], - 1 ) # grid_init / grid_init.abs().max(-1, keepdi=True)[0]这步计算得到8个头对应的坐标偏移: # (1, 0), (1, 1), (0, 1), (-1, 1), (-1, 0), (-1, -1), (0, -1), (1, -1) # 从图形视觉上来看, 形成的偏移位置相当于是3x3, 5x5, 7x7, 9x9正方形卷积核(出去中心,中心是参考点本身) for i in range ( self .n_points): grid_init[:, :, i, :] * = i + 1 # 注意这里取消了梯度,只是借助nn.Parameter把数值设置进去 with torch.no_grad(): self .sampling_offsets.bias = nn.Parameters(grid_init.view( - 1 )) |
最终效果就是,初始的采样点位置相当于会分布在参考点 3x3、5x5、7x7、9x9 方形邻域。
来源:Deformable DETR论文精读+代码详解 - yejian's blog
【推荐】国内首个AI IDE,深度理解中文开发场景,立即下载体验Trae
【推荐】编程新体验,更懂你的AI,立即体验豆包MarsCode编程助手
【推荐】抖音旗下AI助手豆包,你的智能百科全书,全免费不限次数
【推荐】轻量又高性能的 SSH 工具 IShell:AI 加持,快人一步
· Manus爆火,是硬核还是营销?
· 终于写完轮子一部分:tcp代理 了,记录一下
· 震惊!C++程序真的从main开始吗?99%的程序员都答错了
· 别再用vector<bool>了!Google高级工程师:这可能是STL最大的设计失误
· 单元测试从入门到精通
2024-01-23 转:VS code 选择指定环境下的python运行代码
2021-01-23 转:三维动作捕捉技术介绍