Pytorch torch.meshgrid() 在目标检测中的应用

概述

最近在学习目标检测的相关算法。在我看来目标检测要比分类、语义分割任务复杂的多,后者一般只需要为每个图像预测一个标签(分类)或者为每个像素预测一个标签(分割)。而目标检测还需要回归目标边界框同时进行分类,这使得目标检测的数据处理和训练比较复杂。

在目标检测中,一般是通过神经网络提取图像特征,得到下采样stride步幅的特征图,在特征图的每个cell上进行预测,最后将在特征层上预测的结果map回原图尺寸上。这时除了stride,还需要知道特征图的网格坐标。这时就可以用到torch.meshgrid()方法生成网格坐标。 

使用

在使用torch.meshgrid()前,简单说一下图像坐标。如下图所示,图像坐标的原点是左上角、x轴是宽、指向右;y轴是高、指向下。

需要注意的是,在pytorch中,tensor的shape一般都是(..., h, w)。要注意使用(x, y)时,坐标的对应关系,x对应的是w、y对应的是h。 我们需要得到grids网格坐标,也就是每个cell的左上角坐标。

这时我们就可以使用meshgrid方法。

feat = torch.randn(3, 4, 6) hsize, wsize = feat.shape[-2:] 
yv, xv = torch.meshgrid([torch.arange(hsize), torch.arange(wsize)], indexing="ij") 
grids = torch.stack((xv, yv), 2)

通过torch.meshgrid()生成了yv和xv,其内容如下:

# yv tensor([[0, 0, 0, 0, 0, 0], 
#            [1, 1, 1, 1, 1, 1], 
#            [2, 2, 2, 2, 2, 2],
#            [3, 3, 3, 3, 3, 3]])
#
# (hsize, wsize)

# xv tensor([[0, 1, 2, 3, 4, 5],
#            [0, 1, 2, 3, 4, 5],
#            [0, 1, 2, 3, 4, 5],
#            [0, 1, 2, 3, 4, 5]])
#
# (hsize, wsize)

 

然后通过torch.stack()将生成的xv与yv组合起来,就得到了grid网格坐标。

# tensor([[[0, 0], [1, 0], [2, 0], [3, 0], [4, 0], [5, 0]],
#         [[0, 1], [1, 1], [2, 1], [3, 1], [4, 1], [5, 1]],
#         [[0, 2], [1, 2], ...
#                          [2, 3], [3, 3], [4, 3], [5, 3]]
#        ])
#
# (hsize, wsize, 2)

其实torch.meshgrid()方法中的indexing变量在最初版本中是没有的。我们生成坐标时,都是将yv放在前面,组合时再将xv,yv stack起来,形成符合直观的网格坐标(先行再列的顺序,对应了图像的xy坐标系)。

而现在新版本的indexing变量可以通过indexing="ij" or indexing="xy"控制格式,现在仍然保留原始的方法,使用默认的ij index,大概是出于和之前代码保持一致。

 

示例

 为了直观的感受到grid网格坐标与原图的关系,我们可以将grid坐标乘以stride步幅后,映射到原图中。原图如下:

随便写一个下采样步幅stride=32的网络。

conv = nn.Sequential( 
    nn.Conv2d(in_channels=3, out_channels=3, kernel_size=3, stride=2, padding=1),  # 2
    nn.ReLU(),
    nn.Conv2d(in_channels=3, out_channels=3, kernel_size=3, stride=2, padding=1),  # 4
    nn.ReLU(),
    nn.Conv2d(in_channels=3, out_channels=3, kernel_size=3, stride=2, padding=1),  # 8
    nn.ReLU(),
    nn.Conv2d(in_channels=3, out_channels=3, kernel_size=3, stride=2, padding=1),  # 16
    nn.ReLU(),
    nn.Conv2d(in_channels=3, out_channels=3, kernel_size=3, stride=2, padding=1),  # 32
)

提取特征并且使用上述方法得到网格坐标

img = Image.open("flower.jpg").resize((640, 416)).convert('RGB')
img = ToTensor()(img)  # (3, 416, 600)

stride = 32
output = conv(img)  # (3, 13, 20)

hsize, wsize = output.shape[-2:]
yv, xv = torch.meshgrid([torch.arange(hsize), torch.arange(wsize)], indexing='ij')
grid = torch.stack((xv, yv), 2)

将网格相对坐标乘以步幅得到原图绝对坐标,并通过修改img的灰度值使其可视化。

grid = grid * stride
grid = grid.view(-1, 2)
for x, y in grid:
    img[:, y, x] = 0.  # (y, x) -> (h, w)

得到最终的结果,我们看到特征图的grid cell 映射回原图的样子(黑点即为每个grid cell的坐标,为左上角)。

posted @ 2023-01-25 21:30  Brisling  阅读(924)  评论(0编辑  收藏  举报