pytorch ssd 代码疑惑, flt[(rank < self.top_k).unsqueeze(-1).expand_as(flt)].fill_(0)
# Detection post-processing step quoted from ssd.pytorch: for every image and
# every class from index 1 upward, threshold the scores, run NMS, and store
# the kept [score, box] rows in `output`.
# NOTE: indentation was lost when this snippet was pasted; the statements are
# reproduced exactly as they appear in the post.
def forward(self, loc_data, conf_data, prior_data):
"""
Args:
loc_data: (tensor) Loc preds from loc layers
Shape: [batch,num_priors*4]
conf_data: (tensor) Shape: Conf preds from conf layers
Shape: [batch*num_priors,num_classes]
prior_data: (tensor) Prior boxes and variances from priorbox layers
Shape: [1,num_priors,4]
"""
num = loc_data.size(0) # batch size
# NOTE(review): the docstring gives prior_data as [1,num_priors,4], but
# size(0) is used as num_priors here, which implies [num_priors,4] - confirm.
num_priors = prior_data.size(0)
# Result buffer: one row of 5 values (score followed by 4 box values) for up
# to top_k detections per class per image; rows never written stay zero.
output = torch.zeros(num, self.num_classes, self.top_k, 5)
# Reshape to (num, num_priors, num_classes), then swap the last two axes so
# that conf_preds[i][cl] holds the scores of class `cl` over all priors.
conf_preds = conf_data.view(num, num_priors,
self.num_classes).transpose(2, 1)
# Decode predictions into bboxes.
for i in range(num):
# `decode` is defined elsewhere; its result is indexed below as rows of
# 4 values per prior.
decoded_boxes = decode(loc_data[i], prior_data, self.variance)
# For each class, perform nms
conf_scores = conf_preds[i].clone()
# The loop starts at 1, so class 0 is skipped (presumably the background
# class - confirm against the dataset's label map).
for cl in range(1, self.num_classes):
# Keep only the priors whose score for this class exceeds conf_thresh.
c_mask = conf_scores[cl].gt(self.conf_thresh)
scores = conf_scores[cl][c_mask]
# Nothing passed the threshold for this class: leave its rows at zero.
if scores.size(0) == 0:
continue
# Broadcast the per-prior mask across the box columns so the same priors
# can be selected from decoded_boxes, then restore the (-1, 4) layout.
l_mask = c_mask.unsqueeze(1).expand_as(decoded_boxes)
boxes = decoded_boxes[l_mask].view(-1, 4)
# idx of highest scoring and non-overlapping boxes per class
# `nms` is defined elsewhere; only the first `count` entries of `ids` are
# used below.
ids, count = nms(boxes, scores, self.nms_thresh, self.top_k)
# Store [score, box values] for each kept detection of this class.
output[i, cl, :count] = \
torch.cat((scores[ids[:count]].unsqueeze(1),
boxes[ids[:count]]), 1)
# Flatten the class and top_k axes so rows can be ranked across all classes.
flt = output.contiguous().view(num, -1, 5)
# Sort rows by score (column 0), highest first; `idx` holds the sort order.
_, idx = flt[:, :, 0].sort(1, descending=True)
# Sorting the sort order gives each row's rank (0 = highest score).
_, rank = idx.sort(1)
# NOTE(review): this statement has no effect on `flt` or `output`. Indexing
# with a boolean mask produces a new tensor, so fill_(0) only modifies that
# temporary. The mask `rank < self.top_k` also selects the highest-scoring
# rows rather than the ones to discard, which looks inverted - confirm.
flt[(rank < self.top_k).unsqueeze(-1).expand_as(flt)].fill_(0)
return output
这段代码疑惑:
flt = output.contiguous().view(num, -1, 5)
_, idx = flt[:, :, 0].sort(1, descending=True)
_, rank = idx.sort(1)
flt[(rank < self.top_k).unsqueeze(-1).expand_as(flt)].fill_(0)
return output
我自己写了一段测试脚本来验证这几行代码,发现它对 output 没有任何影响(脚本及其输出如下):
# Minimal reproduction: checks whether filling a boolean-mask selection of
# `flt` changes `flt` and the `output` tensor it was viewed from.
import torch
batch_size = 1
num_classes = 2
top_k = 4
# output has shape [1, 2, 4, 5] in this test
output = torch.zeros(batch_size, num_classes, top_k, 5)  # presumably stands in for [b, 21, 200, 5] in the real model
# Fill both classes with random rows; column 0 plays the role of the score.
a0 = torch.rand(4, 5)
a1 = torch.rand(4, 5)
output[0, 0, :] = a0
output[0, 1, :] = a1
print("==================== output==")
print(output)
# Merge the class and top_k axes into one axis of rows.
flt = output.contiguous().view(batch_size, -1, 5)  # [b, 2*4, 5] here; [b, 21*200, 5] in the real model
print("==================== flt==")
print(flt)
# Sort rows by column 0, highest first; `idx` holds the sort order.
_, idx = flt[:, :, 0].sort(1, descending=True)
# Sorting the sort order gives each row's rank (0 = highest value in column 0).
_, rank = idx.sort(1)
# Original statement under test (with -100 instead of 0 so a change would be
# visible). Boolean-mask indexing produces a new tensor, so fill_ modifies
# only that temporary and `flt` is left unchanged.
flt[(rank < top_k).unsqueeze(-1).expand_as(flt)].fill_(-100)  # original form from ssd.pytorch
# Alternative suggested in the upstream issue, kept commented out for
# comparison: indexed assignment writes into `flt` in place, and the mask
# `rank >= top_k` selects the rows outside the top_k.
#flt[(rank >= top_k).unsqueeze(-1).expand_as(flt)] = -100
print("====================last flt==")
print(flt)
print("====================last output==")
print(output)
==================== output==
tensor([[[[0.8621, 0.2626, 0.6104, 0.9218, 0.3547],
[0.2925, 0.8051, 0.8366, 0.7753, 0.0779],
[0.4998, 0.7976, 0.3025, 0.4936, 0.8532],
[0.0884, 0.6303, 0.1796, 0.3239, 0.7133]],
[[0.9649, 0.0333, 0.3988, 0.6702, 0.7215],
[0.6214, 0.2352, 0.2797, 0.5770, 0.3067],
[0.1836, 0.9779, 0.6925, 0.6443, 0.2149],
[0.0182, 0.4632, 0.8495, 0.2121, 0.5690]]]])
==================== flt==
tensor([[[0.8621, 0.2626, 0.6104, 0.9218, 0.3547],
[0.2925, 0.8051, 0.8366, 0.7753, 0.0779],
[0.4998, 0.7976, 0.3025, 0.4936, 0.8532],
[0.0884, 0.6303, 0.1796, 0.3239, 0.7133],
[0.9649, 0.0333, 0.3988, 0.6702, 0.7215],
[0.6214, 0.2352, 0.2797, 0.5770, 0.3067],
[0.1836, 0.9779, 0.6925, 0.6443, 0.2149],
[0.0182, 0.4632, 0.8495, 0.2121, 0.5690]]])
====================last flt==
tensor([[[0.8621, 0.2626, 0.6104, 0.9218, 0.3547],
[0.2925, 0.8051, 0.8366, 0.7753, 0.0779],
[0.4998, 0.7976, 0.3025, 0.4936, 0.8532],
[0.0884, 0.6303, 0.1796, 0.3239, 0.7133],
[0.9649, 0.0333, 0.3988, 0.6702, 0.7215],
[0.6214, 0.2352, 0.2797, 0.5770, 0.3067],
[0.1836, 0.9779, 0.6925, 0.6443, 0.2149],
[0.0182, 0.4632, 0.8495, 0.2121, 0.5690]]])
====================last output==
tensor([[[[0.8621, 0.2626, 0.6104, 0.9218, 0.3547],
[0.2925, 0.8051, 0.8366, 0.7753, 0.0779],
[0.4998, 0.7976, 0.3025, 0.4936, 0.8532],
[0.0884, 0.6303, 0.1796, 0.3239, 0.7133]],
[[0.9649, 0.0333, 0.3988, 0.6702, 0.7215],
[0.6214, 0.2352, 0.2797, 0.5770, 0.3067],
[0.1836, 0.9779, 0.6925, 0.6443, 0.2149],
[0.0182, 0.4632, 0.8495, 0.2121, 0.5690]]]])
Process finished with exit code 0
看了issue,有人也发现这个问题了。
https://github.com/amdegroot/ssd.pytorch/issues/168
可能这种写法在 PyTorch 0.4 上是有效的,在更高版本上就不起作用了吧。更直接的原因是:用布尔掩码做索引 flt[mask] 得到的是一个新的张量(拷贝),fill_ 只修改了这个临时张量,所以 flt 和 output 都不会改变。
评论有人给出了解决方案就是代码里面我注释的那句话,
flt[(rank < top_k).unsqueeze(-1).expand_as(flt)].fill_(-100) ##src
改为
flt[(rank >= top_k).unsqueeze(-1).expand_as(flt)] = -100
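这样是有效的:带掩码的索引赋值会原地写入 flt,而 flt 是 output 的视图,所以 output 也随之被修改。注意条件也从 rank < top_k 改成了 rank >= top_k,即保留第 0 列得分最高的 top_k 行,其余行被覆盖(这里填 -100 只是为了便于观察,原代码中填的是 0)。输出如下: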
==================== output==
tensor([[[[0.2341, 0.2941, 0.4434, 0.2481, 0.7296],
[0.3081, 0.6865, 0.7391, 0.9371, 0.1801],
[0.9775, 0.9983, 0.1749, 0.1505, 0.1860],
[0.0919, 0.7764, 0.6790, 0.7079, 0.6412]],
[[0.5518, 0.2866, 0.6437, 0.1184, 0.8749],
[0.6722, 0.4248, 0.6839, 0.9222, 0.8995],
[0.6662, 0.9287, 0.3097, 0.6207, 0.5590],
[0.6176, 0.4586, 0.5354, 0.6958, 0.4959]]]])
==================== flt==
tensor([[[0.2341, 0.2941, 0.4434, 0.2481, 0.7296],
[0.3081, 0.6865, 0.7391, 0.9371, 0.1801],
[0.9775, 0.9983, 0.1749, 0.1505, 0.1860],
[0.0919, 0.7764, 0.6790, 0.7079, 0.6412],
[0.5518, 0.2866, 0.6437, 0.1184, 0.8749],
[0.6722, 0.4248, 0.6839, 0.9222, 0.8995],
[0.6662, 0.9287, 0.3097, 0.6207, 0.5590],
[0.6176, 0.4586, 0.5354, 0.6958, 0.4959]]])
====================last flt==
tensor([[[-100.0000, -100.0000, -100.0000, -100.0000, -100.0000],
[-100.0000, -100.0000, -100.0000, -100.0000, -100.0000],
[ 0.9775, 0.9983, 0.1749, 0.1505, 0.1860],
[-100.0000, -100.0000, -100.0000, -100.0000, -100.0000],
[-100.0000, -100.0000, -100.0000, -100.0000, -100.0000],
[ 0.6722, 0.4248, 0.6839, 0.9222, 0.8995],
[ 0.6662, 0.9287, 0.3097, 0.6207, 0.5590],
[ 0.6176, 0.4586, 0.5354, 0.6958, 0.4959]]])
====================last output==
tensor([[[[-100.0000, -100.0000, -100.0000, -100.0000, -100.0000],
[-100.0000, -100.0000, -100.0000, -100.0000, -100.0000],
[ 0.9775, 0.9983, 0.1749, 0.1505, 0.1860],
[-100.0000, -100.0000, -100.0000, -100.0000, -100.0000]],
[[-100.0000, -100.0000, -100.0000, -100.0000, -100.0000],
[ 0.6722, 0.4248, 0.6839, 0.9222, 0.8995],
[ 0.6662, 0.9287, 0.3097, 0.6207, 0.5590],
[ 0.6176, 0.4586, 0.5354, 0.6958, 0.4959]]]])
Process finished with exit code 0
好记性不如烂键盘---点滴、积累、进步!