Pytorch torch.meshgrid() 在目标检测中的应用
概述
最近在学习目标检测的相关算法。在我看来目标检测要比分类、语义分割任务复杂的多,后者一般只需要为每个图像预测一个标签(分类)或者为每个像素预测一个标签(分割)。而目标检测还需要回归目标边界框同时进行分类,这使得目标检测的数据处理和训练比较复杂。
在目标检测中,一般是通过神经网络提取图像特征,得到下采样stride步幅的特征图,在特征图的每个cell上进行预测,最后将在特征层上预测的结果map回原图尺寸上。这时除了stride,还需要知道特征图的网格坐标。这时就可以用到torch.meshgrid()方法生成网格坐标。
使用
在使用torch.meshgrid()前,简单说一下图像坐标。如下图所示,图像坐标的原点是左上角、x轴是宽、指向右;y轴是高、指向下。
需要注意的是,在pytorch中,tensor的shape一般都是(..., h, w)。要注意使用(x, y)时,坐标的对应关系,x对应的是w、y对应的是h。 我们需要得到grids网格坐标,也就是每个cell的左上角坐标。
这时我们就可以使用meshgrid方法。
feat = torch.randn(3, 4, 6) hsize, wsize = feat.shape[-2:]
yv, xv = torch.meshgrid([torch.arange(hsize), torch.arange(wsize)], indexing="ij")
grids = torch.stack((xv, yv), 2)
通过torch.meshgrid()生成了yv和xv,其内容如下:
# yv tensor([[0, 0, 0, 0, 0, 0],
# [1, 1, 1, 1, 1, 1],
# [2, 2, 2, 2, 2, 2],
# [3, 3, 3, 3, 3, 3]])
#
# (hsize, wsize)
# xv tensor([[0, 1, 2, 3, 4, 5],
# [0, 1, 2, 3, 4, 5],
# [0, 1, 2, 3, 4, 5],
# [0, 1, 2, 3, 4, 5]])
#
# (hsize, wsize)
然后通过torch.stack()将生成的xv与yv组合起来,就得到了grid网格坐标。
# tensor([[[0, 0], [1, 0], [2, 0], [3, 0], [4, 0], [5, 0]],
# [[0, 1], [1, 1], [2, 1], [3, 1], [4, 1], [5, 1]],
# [[0, 2], [1, 2], ...
# [2, 3], [3, 3], [4, 3], [5, 3]]
# ])
#
# (hsize, wsize, 2)
其实torch.meshgrid()方法中的indexing变量在最初版本中是没有的。我们生成坐标时,都是将yv放在前面,组合时再将xv,yv stack起来,形成符合直观的网格坐标(先行再列的顺序,对应了图像的xy坐标系)。
而现在新版本的indexing变量可以通过indexing="ij" or indexing="xy"控制格式,现在仍然保留原始的方法,使用默认的ij index,大概是出于和之前代码保持一致。
示例
为了直观的感受到grid网格坐标与原图的关系,我们可以将grid坐标乘以stride步幅后,映射到原图中。原图如下:
随便写一个下采样步幅stride=32的网络。
conv = nn.Sequential(
nn.Conv2d(in_channels=3, out_channels=3, kernel_size=3, stride=2, padding=1), # 2
nn.ReLU(),
nn.Conv2d(in_channels=3, out_channels=3, kernel_size=3, stride=2, padding=1), # 4
nn.ReLU(),
nn.Conv2d(in_channels=3, out_channels=3, kernel_size=3, stride=2, padding=1), # 8
nn.ReLU(),
nn.Conv2d(in_channels=3, out_channels=3, kernel_size=3, stride=2, padding=1), # 16
nn.ReLU(),
nn.Conv2d(in_channels=3, out_channels=3, kernel_size=3, stride=2, padding=1), # 32
)
提取特征并且使用上述方法得到网格坐标
img = Image.open("flower.jpg").resize((640, 416)).convert('RGB')
img = ToTensor()(img) # (3, 416, 600)
stride = 32
output = conv(img) # (3, 13, 20)
hsize, wsize = output.shape[-2:]
yv, xv = torch.meshgrid([torch.arange(hsize), torch.arange(wsize)], indexing='ij')
grid = torch.stack((xv, yv), 2)
将网格相对坐标乘以步幅得到原图绝对坐标,并通过修改img的灰度值使其可视化。
grid = grid * stride
grid = grid.view(-1, 2)
for x, y in grid:
img[:, y, x] = 0. # (y, x) -> (h, w)
得到最终的结果,我们看到特征图的grid cell 映射回原图的样子(黑点即为每个grid cell的坐标,为左上角)。