[学习笔记]K-D树
简述
K-D树的本质是一棵二叉查找树,但每一层划分的标准变为某一维度,以垂直于某一坐标轴的超平面将当前区域划分为两个区域
但和二叉查找树不同的是K-D树每个节点储存了一个样本,简单理解为每个节点都代表插入的一个点
构建
考虑当前区域按第\(dim\)维划分,为了让树尽量平衡,将这个区域内所有点按第\(dim\)维排序后,从第\(mid\)个点处划分最优,\(C++\)可以用\(nth\_element\)快速求出这个点,顺便把排在这个点之前的点全都放到它左边
当前节点存下第\(mid\)个点的信息,然后递归的构建左子树和右子树
代码(二维)如下:
KD_Tree::Node *KD_Tree::build(int l, int r, int dim) {
if (l > r) return NULL;
Node *rt = new Node();
int mid = (l + r + 1) >> 1;
std::nth_element(pt + l, pt + mid, pt + r + 1, cmp[dim]);
rt->cur = pt[mid];
rt->ls = build(l, mid - 1, dim ^ 1);
rt->rs = build(mid + 1, r, dim ^ 1);
push_up(rt);
return rt;
}
其中\(l\)到\(r\)的点是区域内的点,\(dim\)代表当前维度
查询
以查询距点\(p\)最近的点为例
依然是递归查找
首先查看当前节点所代表的点是否更优,然后估计左右子树距这个点可能的最近距离
如果估计值更优,就继续往子树查找,否则就不用找下去了
当左右子树都更优时,显然先找较优的子树不会更差,若找完较优的子树后另一棵子树还有可能更优,再到查找另一棵子树中查找
代码(二维)如下:
LL queryMin(KD_Tree *rt, Point &tar) {
LL res = dist(rt->data, tar);
if (!res) res = INF;
LL dl = (rt->son[0] ? getMin(rt->son[0], tar) : INF), dr = (rt->son[1] ? getMin(rt->son[1], tar) : INF);
if (dl > dr) {
if (dr < res) res = std::min(res, queryMin(rt->son[1], tar));
if (dl < res) res = std::min(res, queryMin(rt->son[0], tar));
} else {
if (dl < res) res = std::min(res, queryMin(rt->son[0], tar));
if (dr < res) res = std::min(res, queryMin(rt->son[1], tar));
}
return res;
}
节点记录的东西
通常需要记录左右儿子、每一维坐标的最大及最小值、所代表的点的信息,划分的维度可以记录,也可以递归过程中处理出来
其它信息根据题目要求
总结
看起来非常高端的K-D树其实是很朴素的搜索加上很神奇的剪枝
两句话概括就是:
- 构建——循环用每一维构建二叉查找树,记录信息
- 查询——如果子树中可能有更优解就进入查找,否则退出,优先查找可能的解更优的那颗子树
是不是异常简单?(然而我学了半个星期才打出模板)
代码
最近点对的找不到了,放个这题[CQOI2016]K远点对吧,也挺裸的
#include <cstdio>
#include <iostream>
#include <cstring>
#include <algorithm>
#include <queue>
#define sqr(x) ((x) * (x))
#define MAXN 100005
typedef long long LL;
struct Point {
LL cor[2];
Point(LL d0 = 0, LL d1 = 1) { cor[0] = d0, cor[1] = d1; }
} pt[MAXN];
struct KD_Tree {
struct Node {
Node *ls, *rs;
Point cur;
LL maxc[2], minc[2];
} * root;
void push_up(Node *);
Node* build(int, int, int);
void query(Node *, int, const Point &);
} kd;
int N, K;
std::priority_queue<LL, std::vector<LL>, std::greater<LL> > que;
bool cmp0(const Point &, const Point &);
bool cmp1(const Point &, const Point &);
bool (* cmp[])(const Point &, const Point &) = {cmp0, cmp1};
inline LL dist(const Point &p1, const Point &p2) { return sqr(p1.cor[0] - p2.cor[0]) + sqr(p1.cor[1] - p2.cor[1]); }
inline LL max_dist(const Point &p, KD_Tree::Node *rt) {
return std::max(sqr(p.cor[0] - rt->maxc[0]), sqr(p.cor[0] - rt->minc[0]))
+ std::max(sqr(p.cor[1] - rt->maxc[1]), sqr(p.cor[1] - rt->minc[1]));
}
int main() {
std::ios::sync_with_stdio(false);
std::cin >> N >> K;
K <<= 1;
for (int i = 1; i <= K; ++i) que.push(0);
for (int i = 1; i <= N; ++i)
std::cin >> pt[i].cor[0] >> pt[i].cor[1];
kd.root = kd.build(1, N, 0);
for (int i = 1; i <= N; ++i)
kd.query(kd.root, 0, pt[i]);
std::cout << que.top() << std::endl;
return 0;
}
inline void max(LL &a, LL b) { a = std::max(a, b); }
inline void min(LL &a, LL b) { a = std::min(a, b); }
bool cmp0(const Point &p1, const Point &p2) { return p1.cor[0] < p2.cor[0]; }
bool cmp1(const Point &p1, const Point &p2) { return p1.cor[1] < p2.cor[1]; }
KD_Tree::Node *KD_Tree::build(int l, int r, int dim) {
if (l > r) return NULL;
Node *rt = new Node();
int mid = (l + r + 1) >> 1;
std::nth_element(pt + l, pt + mid, pt + r + 1, cmp[dim]);
rt->cur = pt[mid];
rt->ls = build(l, mid - 1, dim ^ 1);
rt->rs = build(mid + 1, r, dim ^ 1);
push_up(rt);
return rt;
}
void KD_Tree::query(Node *rt, int dim, const Point &p) {
if (!rt) return;
if (dist(p, rt->cur) > que.top()) { que.pop(); que.push(dist(p, rt->cur)); }
LL dl = -0x3f3f3f3f3f3f3f3f, dr = -0x3f3f3f3f3f3f3f3f;
if (rt->ls) dl = max_dist(p, rt->ls);
if (rt->rs) dr = max_dist(p, rt->rs);
if (dl > dr) {
if (dl > que.top()) query(rt->ls, dim ^ 1, p);
if (dr > que.top()) query(rt->rs, dim ^ 1, p);
} else {
if (dr > que.top()) query(rt->rs, dim ^ 1, p);
if (dl > que.top()) query(rt->ls, dim ^ 1, p);
}
}
void KD_Tree::push_up(Node *rt) {
rt->maxc[0] = rt->minc[0] = rt->cur.cor[0];
rt->maxc[1] = rt->minc[1] = rt->cur.cor[1];
if (rt->ls) {
max(rt->maxc[0], rt->ls->maxc[0]); min(rt->minc[0], rt->ls->minc[0]);
max(rt->maxc[1], rt->ls->maxc[1]); min(rt->minc[1], rt->ls->minc[1]);
}
if (rt->rs) {
max(rt->maxc[0], rt->rs->maxc[0]); min(rt->minc[0], rt->rs->minc[0]);
max(rt->maxc[1], rt->rs->maxc[1]); min(rt->minc[1], rt->rs->minc[1]);
}
}
//Rhein_E