KD 树(K-Dimensional Tree)在实现示例
以下是一个简单的KD树(K-Dimensional Tree)在C++中的实现示例,用于处理二维点数据的情况。KD树是一种用于对k维空间中的数据点进行划分的数据结构,常用于快速查找最近邻点等操作。
#include <iostream>
#include <vector>
#include <cmath>
#include <algorithm>
using namespace std;
// 定义点结构体
struct Point {
double x;
double y;
Point(double _x = 0, double _y = 0) : x(_x), y(_y) {}
};
// KD树节点结构体
struct KDNode {
Point point;
KDNode* left;
KDNode* right;
KDNode(const Point& p) : point(p), left(nullptr), right(nullptr) {}
};
// 计算两点之间的欧几里得距离
double distance(const Point& p1, const Point& p2) {
return sqrt((p1.x - p2.x) * (p1.x - p2.x) + (p1.y - p2.y) * (p1.y - p2.y));
}
// 构建KD树
KDNode* buildKDTree(const vector<Point>& points, int depth) {
if (points.empty()) {
return nullptr;
}
int k = 2; // 二维空间
int axis = depth % k;
// 根据当前划分维度对 points 进行排序
vector<Point> sortedPoints = points;
if (axis == 0) {
sort(sortedPoints.begin(), sortedPoints.end(), [](const Point& p1, const Point& p2) {
return p1.x < p2.x;
});
} else {
sort(sortedPoints.begin(), sortedPoints.end(), [](const Point& p1, const Point& p2) {
return p1.y < p2.y;
});
}
int medianIndex = sortedPoints.size() / 2;
KDNode* node = new KDNode(sortedPoints[medianIndex]);
vector<Point> leftPoints(sortedPoints.begin(), sortedPoints.begin() + medianIndex);
vector<Point> rightPoints(sortedPoints.begin() + medianIndex + 1, sortedPoints.end());
node->left = buildKDTree(leftPoints, depth + 1);
node->right = buildKDTree(rightPoints, depth + 1);
return node;
}
// 查找KD树中距离目标点最近的点
Point findNearestNeighbor(KDNode* root, const Point& target, int depth) {
if (root == nullptr) {
return Point();
}
int k = 2;
int axis = depth % k;
KDNode* nextBranch = nullptr;
KDNode* otherBranch = nullptr;
if ((axis == 0 && target.x < root->point.x) || (axis == 1 && target.y < root->point.y)) {
nextBranch = root->left;
otherBranch = root->right;
} else {
nextBranch = root->right;
otherBranch = root->left;
}
Point best = findNearestNeighbor(nextBranch, target, depth + 1);
double bestDistance = distance(best, target);
if (distance(root->point, target) < bestDistance) {
best = root->point;
bestDistance = distance(root->point, target);
}
double splitDistance = 0;
if (axis == 0) {
splitDistance = fabs(root->point.x - target.x);
} else {
splitDistance = fabs(root->point.y - target.y);
}
if (splitDistance < bestDistance) {
Point otherBest = findNearestNeighbor(otherBranch, target, depth + 1);
double otherBestDistance = distance(otherBest, target);
if (otherBestDistance < bestDistance) {
best = otherBest;
}
}
return best;
}
你可以使用以下方式来测试上述代码:
int main() {
vector<Point> points = {
Point(2, 3),
Point(5, 4),
Point(9, 6),
Point(4, 7),
Point(8, 1),
Point(7, 2)
};
KDNode* root = buildKDTree(points, 0);
Point target(6, 3);
Point nearestNeighbor = findNearestNeighbor(root, target, 0);
cout << "目标点 (" << target.x << ", " << target.y << ") 最近邻点为 (" << nearestNeighbor.x << ", " << nearestNeighbor.y << ")" << endl;
return 0;
}
在上述代码中:
Point
结构体用于表示二维空间中的点。KDNode
结构体用于表示KD树的节点,包含一个点以及左右子节点指针。distance
函数用于计算两点之间的欧几里得距离。buildKDTree
函数通过递归的方式构建KD树,根据当前划分维度对数据点进行排序并选取中位数作为节点,然后分别构建左右子树。findNearestNeighbor
函数用于在KD树中查找距离目标点最近的点,通过比较当前节点、子树中的点以及跨分割平面的点来确定最近邻点。
请注意,这只是一个简单的示例,实际应用中可能需要根据具体需求进行更多的优化和扩展,比如处理更高维度的数据、更高效的内存管理等。