KD 树(K-Dimensional Tree)在实现示例

以下是一个简单的KD树(K-Dimensional Tree)在C++中的实现示例,用于处理二维点数据的情况。KD树是一种用于对k维空间中的数据点进行划分的数据结构,常用于快速查找最近邻点等操作。

#include <iostream>
#include <vector>
#include <cmath>
#include <algorithm>

using namespace std;

// 定义点结构体
struct Point {
    double x;
    double y;

    Point(double _x = 0, double _y = 0) : x(_x), y(_y) {}
};

// KD树节点结构体
struct KDNode {
    Point point;
    KDNode* left;
    KDNode* right;

    KDNode(const Point& p) : point(p), left(nullptr), right(nullptr) {}
};

// 计算两点之间的欧几里得距离
double distance(const Point& p1, const Point& p2) {
    return sqrt((p1.x - p2.x) * (p1.x - p2.x) + (p1.y - p2.y) * (p1.y - p2.y));
}

// 构建KD树
KDNode* buildKDTree(const vector<Point>& points, int depth) {
    if (points.empty()) {
        return nullptr;
    }

    int k = 2; // 二维空间
    int axis = depth % k;

    // 根据当前划分维度对 points 进行排序
    vector<Point> sortedPoints = points;
    if (axis == 0) {
        sort(sortedPoints.begin(), sortedPoints.end(), [](const Point& p1, const Point& p2) {
            return p1.x < p2.x;
        });
    } else {
        sort(sortedPoints.begin(), sortedPoints.end(), [](const Point& p1, const Point& p2) {
            return p1.y < p2.y;
        });
    }

    int medianIndex = sortedPoints.size() / 2;
    KDNode* node = new KDNode(sortedPoints[medianIndex]);

    vector<Point> leftPoints(sortedPoints.begin(), sortedPoints.begin() + medianIndex);
    vector<Point> rightPoints(sortedPoints.begin() + medianIndex + 1, sortedPoints.end());

    node->left = buildKDTree(leftPoints, depth + 1);
    node->right = buildKDTree(rightPoints, depth + 1);

    return node;
}

// 查找KD树中距离目标点最近的点
Point findNearestNeighbor(KDNode* root, const Point& target, int depth) {
    if (root == nullptr) {
        return Point();
    }

    int k = 2;
    int axis = depth % k;

    KDNode* nextBranch = nullptr;
    KDNode* otherBranch = nullptr;

    if ((axis == 0 && target.x < root->point.x) || (axis == 1 && target.y < root->point.y)) {
        nextBranch = root->left;
        otherBranch = root->right;
    } else {
        nextBranch = root->right;
        otherBranch = root->left;
    }

    Point best = findNearestNeighbor(nextBranch, target, depth + 1);
    double bestDistance = distance(best, target);

    if (distance(root->point, target) < bestDistance) {
        best = root->point;
        bestDistance = distance(root->point, target);
    }

    double splitDistance = 0;
    if (axis == 0) {
        splitDistance = fabs(root->point.x - target.x);
    } else {
        splitDistance = fabs(root->point.y - target.y);
    }

    if (splitDistance < bestDistance) {
        Point otherBest = findNearestNeighbor(otherBranch, target, depth + 1);
        double otherBestDistance = distance(otherBest, target);

        if (otherBestDistance < bestDistance) {
            best = otherBest;
        }
    }

    return best;
}

你可以使用以下方式来测试上述代码:

int main() {
    vector<Point> points = {
        Point(2, 3),
        Point(5, 4),
        Point(9, 6),
        Point(4, 7),
        Point(8, 1),
        Point(7, 2)
    };

    KDNode* root = buildKDTree(points, 0);

    Point target(6, 3);
    Point nearestNeighbor = findNearestNeighbor(root, target, 0);

    cout << "目标点 (" << target.x << ", " << target.y << ") 最近邻点为 (" << nearestNeighbor.x << ", " << nearestNeighbor.y << ")" << endl;

    return 0;
}

在上述代码中:

  1. Point结构体用于表示二维空间中的点。
  2. KDNode结构体用于表示KD树的节点,包含一个点以及左右子节点指针。
  3. distance函数用于计算两点之间的欧几里得距离。
  4. buildKDTree函数通过递归的方式构建KD树,根据当前划分维度对数据点进行排序并选取中位数作为节点,然后分别构建左右子树。
  5. findNearestNeighbor函数用于在KD树中查找距离目标点最近的点,通过比较当前节点、子树中的点以及跨分割平面的点来确定最近邻点。

请注意,这只是一个简单的示例,实际应用中可能需要根据具体需求进行更多的优化和扩展,比如处理更高维度的数据、更高效的内存管理等。

posted @ 2024-11-10 10:58  MarsCactus  阅读(10)  评论(0编辑  收藏  举报