nms (non-maximum suppression)

void DstEngine::nms(std::shared_ptr<DeviceStreamData> &input,std::shared_ptr<FaceInfo> &output, float nmsthreshold)
{
float iou = 0.f;
float x1, y1, x2, y2;
float w, h, intersection;
std::vector<int> Picks;

std::multimap<float, int > vscores;
for (int i = 0; i < input.size() ; ++i) {
vscores.insert(std::make_pair(input[i].score, i));
}

while (vscores.size() > 0){
int first = vscores.rbegin()->second; // multimap会自动对键升序排列
Picks.push_back(first);

for (std::multimap<float, int>::iterator it = vscores.begin(); it != vscores.end(); ) {
int current_idx = it->second;
x1 = std::max(input[current_idx].x1, input[first].x1);
y1 = std::max(input[current_idx].y1, input[first].y1);
x2 = std::min(input[current_idx].x2, input[first].x2);
y2 = std::min(input[current_idx].y2, input[first].y2);

w = ((x2 - x1 + 1) > 0) ? (x2 - x1 + 1) : 0;
h = ((y2 - y1 + 1) > 0) ? (y2 - y1 + 1) : 0;
intersection = w * h;

if (nms_type == nms_union)
iou = intersection / (input[current_idx].area + input[first].area - intersection);
else if (nms_type == nms_min)
iou = intersection / std::min(input[current_idx].area, input[first].area);
else
std::cout << "[nms] | something wrong" << std::endl;

if (iou > nmsthreshold) {
it = vscores.erase(it);
} else{
it++;
}
}
}

output.resize(Picks.size());
for (int j = 0; j < Picks.size(); ++j) {
output[j] = input[Picks[j]];
}
}

posted @ 2020-06-16 15:54  闪光123  阅读(242)  评论(0)  编辑  收藏  举报