C++实现nms
怎么理解nms?
非极大值抑制,简单的说就是给出一大堆bbox和相应的得分,对于其中区域重合的box,如果两个box重合部分大于设定的theshold,就抛弃小的那个,直到所有的box
都判定完了。
struct anchor_box
{
float x1;
float y1;
float x2;
float y2;
};
struct FaceDetectInfo
{
float score;
anchor_box rect;
FacePts pts; // 这个标记点在nms中没啥用
};
std::vector<FaceDetectInfo> RetinaFace::nms(std::vector<FaceDetectInfo>& bboxes, float threshold)
{
std::vector<FaceDetectInfo> bboxes_nms;
std::sort(bboxes.begin(), bboxes.end(), CompareBBox);
int32_t select_idx = 0;
int32_t num_bbox = static_cast<int32_t>(bboxes.size());
std::vector<int32_t> mask_merged(num_bbox, 0);
bool all_merged = false;
while (!all_merged) {
while (select_idx < num_bbox && mask_merged[select_idx] == 1)
select_idx++;
if (select_idx == num_bbox) {
all_merged = true;
continue;
}
bboxes_nms.push_back(bboxes[select_idx]);
mask_merged[select_idx] = 1;
anchor_box select_bbox = bboxes[select_idx].rect;
float area1 = static_cast<float>((select_bbox.x2 - select_bbox.x1 + 1) * (select_bbox.y2 - select_bbox.y1 + 1));
float x1 = static_cast<float>(select_bbox.x1);
float y1 = static_cast<float>(select_bbox.y1);
float x2 = static_cast<float>(select_bbox.x2);
float y2 = static_cast<float>(select_bbox.y2);
select_idx++;
for (int32_t i = select_idx; i < num_bbox; i++) {
if (mask_merged[i] == 1)
continue;
anchor_box& bbox_i = bboxes[i].rect;
float x = std::max<float>(x1, static_cast<float>(bbox_i.x1));
float y = std::max<float>(y1, static_cast<float>(bbox_i.y1));
float w = std::min<float>(x2, static_cast<float>(bbox_i.x2)) - x + 1; //<- float 型不加1
float h = std::min<float>(y2, static_cast<float>(bbox_i.y2)) - y + 1;
if (w <= 0 || h <= 0)
continue;
float area2 = static_cast<float>((bbox_i.x2 - bbox_i.x1 + 1) * (bbox_i.y2 - bbox_i.y1 + 1));
float area_intersect = w * h;
if (static_cast<float>(area_intersect) / (area1 + area2 - area_intersect) > threshold) {
mask_merged[i] = 1;
}
}
}
return bboxes_nms;
}
这段代码来自retinaface mnet tensorrt实现中的一个实现,具体地址我忘了。我觉得这段代码可优化空间很大。
代码思路很简单。写段伪代码描述下
input: vector<bbox> bboxs, float threshold, bool all_merged = false
whie (!all_merged):
select it in bboxs where it not merge, if all bboxs have merged, set all_merged true,then break loop.
for other in bboxs:
if 'other' not merge and IOU('it', 'other') > threshold:
set 'other' merged
else continue
【推荐】国内首个AI IDE,深度理解中文开发场景,立即下载体验Trae
【推荐】编程新体验,更懂你的AI,立即体验豆包MarsCode编程助手
【推荐】抖音旗下AI助手豆包,你的智能百科全书,全免费不限次数
【推荐】轻量又高性能的 SSH 工具 IShell:AI 加持,快人一步
· 10年+ .NET Coder 心语,封装的思维:从隐藏、稳定开始理解其本质意义
· .NET Core 中如何实现缓存的预热?
· 从 HTTP 原因短语缺失研究 HTTP/2 和 HTTP/3 的设计差异
· AI与.NET技术实操系列:向量存储与相似性搜索在 .NET 中的实现
· 基于Microsoft.Extensions.AI核心库实现RAG应用
· 10年+ .NET Coder 心语 ── 封装的思维:从隐藏、稳定开始理解其本质意义
· 地球OL攻略 —— 某应届生求职总结
· 提示词工程——AI应用必不可少的技术
· Open-Sora 2.0 重磅开源!
· 周边上新:园子的第一款马克杯温暖上架