K-D Tree
1 引入
\(\text{K-D Tree}\) 是一种高效处理 \(k\) 维空间信息的数据结构。具体的讲,它维护了 \(k\) 维空间中 \(n\) 个点的信息,并且拥有二叉搜索树的形态。
在 \(n\) 远大于 \(2^k\) 时 \(\text{K-D Tree}\) 有较好的时间效率,一般情况下,我们会取 \(k=2\) 即一个二维平面来应用它。
2 基本操作
2.1 建树
我们希望这棵 \(\text{K-D Tree}\) 尽可能是一颗平衡树,那么我们对于所有点对,肯定要选择某一维度上值的中位数作为当前的根,这样可以保证两边的点的数量尽可能接近。
但是如果每一次都按照一维来排序的话,可能会出现所有点在这一维上都很接近,但是在别的维上相差很远的情况,复杂度会爆炸。那么我们就需要采取一个优化措施,我们每一次划分是轮流按照 \(k\) 个维度的每一个来进行划分,这样可以保证每一维都被划分到。如此,可以保证建出的树高为 \(\log n+O(1)\)。
现在的问题是找出区间中某一维度上的中位数的值,如果直接排序的话总复杂度是 \(O(n\log^2 n)\) 的,不过我们实际上只需要找出中位数,并且将小于它的放到左边、大于它的放到右边即可。实际上这可以用 nth_element
来简单实现,复杂度是 \(O(n\log n)\)。
基础代码如下:
void build(int &p, int l, int r, int typ) {
if(l > r) return p = 0, void();
int mid = (l + r) >> 1;
nth_element(a + l, a + mid, a + r + 1, [typ](node x, node y){return x.v[typ] < y.v[typ];});//按照当前维排序,找中位数
p = ++tot;
t[p].v[0] = a[mid].v[0], t[p].v[1] = a[mid].v[1];//赋值
build(lp, l, mid - 1, typ ^ 1), build(rp, mid + 1, r, typ ^ 1);
}
那么此时我们分析一下 \(\text{K-D Tree}\) 的建树过程,不难发现这样一点:一个子树内的所有点恰好对应一个矩形。那么我们可以通过维护子树内每一维坐标的极值来确定该矩形的大小,所以可以得到最终的代码:
struct KD_Tree {
int l, r, v[2], mn[2], mx[2];
// 当前点坐标 坐标最小值 坐标最大值
}t[Maxn];
int tot = 0;
#define lp t[p].l
#define rp t[p].r
void pushup(int p) {
for(int i = 0; i < 2; i++) {
t[p].mn[i] = t[p].mx[i] = t[p].v[i];
if(lp) {
t[p].mn[i] = min(t[p].mn[i], t[lp].mn[i]);
t[p].mx[i] = max(t[p].mx[i], t[lp].mx[i]);
}
if(rp) {
t[p].mn[i] = min(t[p].mn[i], t[rp].mn[i]);
t[p].mx[i] = max(t[p].mx[i], t[rp].mx[i]);
}
}
}
void build(int &p, int l, int r, int typ) {
if(l > r) return p = 0, void();
int mid = (l + r) >> 1;
nth_element(a + l, a + mid, a + r + 1, [typ](node x, node y){return x.v[typ] < y.v[typ];});
p = ++tot;
t[p].v[0] = a[mid].v[0], t[p].v[1] = a[mid].v[1];
build(lp, l, mid - 1, typ ^ 1), build(rp, mid + 1, r, typ ^ 1);
pushup(p);//上传标记
}
2.2 插入与删除
如果我们维护的点集会发生变动,此时静态建树的 \(\text{K-D Tree}\) 的复杂度就无法得到保证。所以我们需要找出一种动态建树的方式。遗憾的是,常见于平衡树的维护平衡的两个操作,即旋转和随机优先级,都不能运用到 \(\text{K-D Tree}\) 上,所以我们通常采用下面两种方式。
2.2.1 根号重构
我们可以想到的是利用替罪羊树的重构套路对 \(\text{K-D Tree}\) 进行重构(即设置一个平衡因子 \(\alpha\)),但是实际上利用替罪羊树进行重构只能保证高度是 \(O(\log n)\),不是严格的 \(\log n+O(1)\),所以查询复杂度可能会退化。但是一般情境下替罪羊式重构也足够通过。
考虑另一种方式,设定一个阈值 \(B\),每次插入的时候直接从根节点开始和每个节点比较并向下递归。当插入次数达到 \(B\) 的时候暴力重构整棵树。删除时仍然采用惰性删除,当树内删除数量达到 \(B\) 的时候继续暴力重构。
如此,当我们取到 \(B=O(\sqrt {n\log n})\) 的时候复杂度最优,为单次均摊 \(O(\sqrt{n\log n})\)。
2.2.2 二进制分组
如果仅仅要求插入,那么这种做法是更优的。我们维护若干棵大小为 \(2^i\) 的 \(\text{K-D Tree}\),满足这些树的大小之和为 \(n\)。
插入的时候,新增一棵大小为 \(2^0\) 的 \(\text{K-D Tree}\),然后不断将相同大小的 \(\text{K-D Tree}\) 进行合并。实际操作的时候可以先将可以合并在一起的所有树拍扁,然后只需要重构一次即可。
这样做的总复杂度是均摊 \(O(n\log^2 n)\) 的,较上一种做法更优秀。
代码如下:
namespace KDT {
//...
int tot = 0;
int trs[Maxn], top;//拍扁的时候会删除节点,可以用垃圾桶来节省空间
void del(int &p) {//删除节点
trs[++top] = p;
t[p] = {0, 0, 0, 0, 0, 0, 0, 0};
p = 0;
}
int newnode() {
return top ? trs[top--] : ++tot;
}
//...
void append(int &p) {//拍扁重构
if(!p) return ;
a[++cnt] = {t[p].v[0], t[p].v[1]};//记录下当前点
append(lp), append(rp);
del(p);//删除
}
//...
}
int main() {
//...
for(int i = 1; i <= n; i++) {
int x, y;
cin >> x >> y;
for(int j = 0; j < 20; j++) {//每一个根查询一边
KDT::query(rt[j], x, y);
}
a[cnt = 1] = {x, y};//开始重构
for(int j = 0; j < 20; j++) {
if(!rt[j]) {//当前大小为 2^j 的树还没有建,无法合并,在这里重建树
KDT::build(rt[j], 1, cnt, 0);
break;
}
else {
KDT::append(rt[j]);//拍平重构
}
}
}
//...
return 0;
}
3 查询操作
3.1 矩阵查询
我们在查询矩阵中所有点的信息时,按照传统的方式去进行递归。如果当前子树对应的矩形和目标矩形无交点,则不继续搜索;否则如果被目标矩形完全包含,直接返回整个子树的信息即可;否则先判断当前节点是否合法,然后再递归下去找答案。
可以证明,这样做的复杂度是单次 \(O(n^{1-\frac 1k})\) 的,证明如下:
考虑将每个节点对应的矩阵分 \(3\) 类:
- 与目标矩阵无交点。
- 完全被目标矩阵包含。
- 与目标矩阵有部分交集。
显然前两种如果递归到我们会直接返回,所以只需要考虑第三种矩阵。而第三种矩阵又分为完全包含目标矩阵的部分和剩下的部分。前者显然最多只有 \(O(\log n)\) 个。
现在考虑后者。我们对于一个节点来看,我们在它的儿子和孙子处分别对 \(x,y\) 坐标进行了一次划分,共划分为了 \(4\) 个子矩阵。考虑查询矩阵的每一条边,此时它经过了几个子矩阵,就代表它还要访问那些子树。显然可以发现,对于任意一条边来讲,它最多经过 \(2\) 个这样的子矩阵。
设当前子树大小为 \(n\),由于我们建树时保证了子树大小每一次减半,所以子矩阵大小应该是 \(\dfrac n4\) 的。于是有以下递归式:
\[T(n)=2T(\dfrac n4)+O(1) \]根据主定理可知 \(T(n)=O(\sqrt n)\)。将其推广至 \(k\) 维可得递归式为 \(T(n)=2^{k-1} T(\dfrac{n}{2^k})+O(1)\),可得 \(T(n)=O(n^{1-\frac 1k})\)。
3.2 邻域查询
邻域查询可以求出平面上一个点的最近 / 最远点。值得注意的是 \(\text{K-D Tree}\) 求解这个问题的复杂度仍是最坏 \(O(n)\) 的,但是在随机数据下表现为均摊 \(O(\log n)\),并且大多数情况下表现较为优秀且很少有人卡,所以不失为一种好的骗分技巧。
假设现在我们要找出离当前点最近的点,我们暴力遍历 \(\text{K-D Tree}\) 上的每一个节点,然后进行剪枝。我们可以对每个子树对应的矩阵设计一个估价函数,例如用查询点到这个矩阵的最短距离作为估价函数,然后进行启发式搜索,先搜索估价函数较小的子树的答案。
同理还可以进行最优性剪枝,如果当前节点的估价函数都比当前答案大,那么子树内不可能有更优的答案,直接返回即可。
4 例题
例 1 [SDOI2010] 捉迷藏
此题就是求最近最远点对的题目。暴力枚举每一个点,求出不是自己的离自己最近和最远的点的曼哈顿距离即可。
代码如下:
#include <bits/stdc++.h>
using namespace std;
const int Maxn = 2e5 + 5;
const int Inf = 2e9;
int n;
struct node {
int v[2];
}a[Maxn];
int rt;
int ans1 = Inf, ans2 = -Inf;
namespace KDT {
struct KD_Tree {
int l, r, v[2], mn[2], mx[2];
}t[Maxn];
int tot = 0;
#define lp t[p].l
#define rp t[p].r
void pushup(int p) {
for(int i = 0; i < 2; i++) {
t[p].mn[i] = t[p].mx[i] = t[p].v[i];
if(lp) {
t[p].mn[i] = min(t[p].mn[i], t[lp].mn[i]);
t[p].mx[i] = max(t[p].mx[i], t[lp].mx[i]);
}
if(rp) {
t[p].mn[i] = min(t[p].mn[i], t[rp].mn[i]);
t[p].mx[i] = max(t[p].mx[i], t[rp].mx[i]);
}
}
}
void build(int &p, int l, int r, int typ) {
if(l > r) return p = 0, void();
int mid = (l + r) >> 1;
nth_element(a + l, a + mid, a + r + 1, [typ](node x, node y){return x.v[typ] < y.v[typ];});
p = ++tot;
t[p].v[0] = a[mid].v[0], t[p].v[1] = a[mid].v[1];
build(lp, l, mid - 1, typ ^ 1), build(rp, mid + 1, r, typ ^ 1);
pushup(p);
}
int dis(int x1, int y1, int x2, int y2) {//距离
return abs(x1 - x2) + abs(y1 - y2);
}
int fmin(int p, int x, int y) {//最小值估价函数
int res = 0;
if(x < t[p].mn[0]) res += t[p].mn[0] - x;
if(x > t[p].mx[0]) res += x - t[p].mx[0];
if(y < t[p].mn[1]) res += t[p].mn[1] - y;
if(y > t[p].mx[1]) res += y - t[p].mx[1];
return res;
}
int fmax(int p, int x, int y) {//最大值估价函数
int res = 0;
res += max(abs(x - t[p].mn[0]), abs(x - t[p].mx[0]));
res += max(abs(y - t[p].mn[1]), abs(y - t[p].mx[1]));
return res;
}
void qmin(int p, int x, int y) {
if(!p) return;
if(!(x == t[p].v[0] && y == t[p].v[1])) ans1 = min(ans1, dis(x, y, t[p].v[0], t[p].v[1]));//注意不能是自己本身
int vl = Inf, vr = Inf;
if(lp) vl = fmin(lp, x, y);
if(rp) vr = fmin(rp, x, y);
if(vl < vr) {//启发式搜索,先搜更小的
if(vl < ans1) qmin(lp, x, y);
if(vr < ans1) qmin(rp, x, y);
}
else {
if(vr < ans1) qmin(rp, x, y);
if(vl < ans1) qmin(lp, x, y);
}
}
void qmax(int p, int x, int y) {
if(!p) return ;
ans2 = max(ans2, dis(x, y, t[p].v[0], t[p].v[1]));
int vl = -Inf, vr = -Inf;
if(lp) vl = fmax(lp, x, y);
if(rp) vr = fmax(rp, x, y);
if(vl > vr) {
if(vl > ans2) qmax(lp, x, y);
if(vr > ans2) qmax(rp, x, y);
}
else {
if(vr > ans2) qmax(rp, x, y);
if(vl > ans2) qmax(lp, x, y);
}
}
}
int main() {
ios::sync_with_stdio(0);
cin.tie(0), cout.tie(0);
cin >> n;
for(int i = 1; i <= n; i++) {
cin >> a[i].v[0] >> a[i].v[1];
}
KDT::build(rt, 1, n, 0);
int ans = Inf;
for(int i = 1; i <= n; i++) {
ans1 = Inf, ans2 = -Inf;
KDT::qmin(rt, a[i].v[0], a[i].v[1]);
KDT::qmax(rt, a[i].v[0], a[i].v[1]);
ans = min(ans, ans2 - ans1);
}
cout << ans << '\n';
return 0;
}
例 2 巧克力王国
题目即求所有 \(ax+by<c\) 的巧克力的 \(h\) 之和。仍然考虑 \(\text{K-D Tree}\),将所有 \(x,y\) 扔到二维平面上建树,然后仍然采用类似矩阵查询的方式来完成。我们单次看矩阵求出的 \(ax+by\) 的最大值和最小值,如果最大值 \(<c\) 则直接加上整个矩阵,如果最小值 \(\ge c\) 则返回,否则继续向下递归求解。
但是遗憾的是,此题的 \(\text{K-D Tree}\) 并不是普通矩阵查询的 \(O(n^{1-\frac 1k})\) 的复杂度,因为这并不是严格意义上的矩阵查询。事实上,它的最坏复杂度仍然是 \(O(n)\)。但是题目中保证了数据随机,因此可以通过。
例 3 [国家集训队] JZPFAR
发现题目现在要求离当前点第 \(k\) 远的点,但是 \(k\) 很小,所以可以考虑用一个小根堆存下当前所有的答案。然后我们仍然采用启发式搜索的方式,只有当小根堆大小 \(<k\) 或者估价函数值比堆顶大的时候才去递归,并且取两个儿子中较大的先递归。
由于保证数据随机,所以 \(\text{K-D Tree}\) 可以通过。
例 4 [BZOJ4605] 崂山白花蛇草水
发现这道题就是一个单点加、矩阵第 \(k\) 大。如果只有单点加和矩阵查询的话我们可以用 \(\text{K-D Tree}\) 做到 \(O(\sqrt n)\)。但是现在我们要求第 \(k\) 大,自然想到利用权值线段树来辅助求解。所以最后不难想到利用树套树来解决这个问题。
接下来我们有两种方法来维护:
- 外层维护下标,内层维护权值。即 \(\text{K-D Tree}\) 套权值线段树。
- 外层维护权值,内层维护下标。即权值线段树套 \(\text{K-D Tree}\)。
第一种做法比较困难,因为这样做的话权值线段树的合并要求可持久化,并且常数也过大。我们采用第二种方法即可,查询时在权值线段树上二分,用 \(\text{K-D Tree}\) 的矩阵查询来求出点的个数然后判断向哪个儿子走即可。
修改的总复杂度是 \(O(q\log^2 q\log V)\),查询复杂度 \(O(q\sqrt q \log V)\)。
采用二进制分组的代码如下:
#include <bits/stdc++.h>
#define il inline
using namespace std;
const int Maxn = 1e5 + 5;
const int Inf = 2e9;
const int N = 1e9;
int n, q;
struct node {
int v[2];
}a[Maxn];
int cnt = 0;
namespace KDT {
struct KD_Tree {
int l, r, v[2], mn[2], mx[2], siz;
}t[Maxn * 30];
#define lp t[p].l
#define rp t[p].r
int tot = 0;
int trs[Maxn * 30], top;
int rt[Maxn * 30][18];
il void del(int &p) {
trs[++top] = p;
t[p] = {0, 0, 0, 0, 0, 0, 0, 0, 0};
p = 0;
}
il int newnode() {
return top ? trs[top--] : ++tot;
}
il void pushup(int p) {
t[p].siz = t[lp].siz + t[rp].siz + 1;
t[p].mn[0] = t[p].mx[0] = t[p].v[0];
t[p].mn[1] = t[p].mx[1] = t[p].v[1];
if(lp) {
t[p].mn[0] = min(t[p].mn[0], t[lp].mn[0]);
t[p].mx[0] = max(t[p].mx[0], t[lp].mx[0]);
t[p].mn[1] = min(t[p].mn[1], t[lp].mn[1]);
t[p].mx[1] = max(t[p].mx[1], t[lp].mx[1]);
}
if(rp) {
t[p].mn[0] = min(t[p].mn[0], t[rp].mn[0]);
t[p].mx[0] = max(t[p].mx[0], t[rp].mx[0]);
t[p].mn[1] = min(t[p].mn[1], t[rp].mn[1]);
t[p].mx[1] = max(t[p].mx[1], t[rp].mx[1]);
}
}
void build(int &p, int l, int r, int typ) {
if(l > r) return ;
int mid = (l + r) >> 1;
nth_element(a + l, a + mid, a + r + 1, [typ](node x, node y){return x.v[typ] < y.v[typ];});
p = newnode();
t[p].v[0] = a[mid].v[0], t[p].v[1] = a[mid].v[1];
build(lp, l, mid - 1, typ), build(rp, mid + 1, r, typ);
pushup(p);
}
void append(int &p) {
if(!p) return ;
a[++cnt] = {t[p].v[0], t[p].v[1]};
append(lp), append(rp);
del(p);
}
il bool chkin(int x, int y, int x1, int y1, int x2, int y2) {
return (x1 <= x && x <= x2 && y1 <= y && y <= y2);
}
il bool checkin(int x1, int y1, int x2, int y2, int x3, int y3, int x4, int y4) {
return (x3 <= x1 && x2 <= x4 && y3 <= y1 && y2 <= y4);
}
il bool checkout(int x1, int y1, int x2, int y2, int x3, int y3, int x4, int y4) {
return (x3 > x2 || x1 > x4 || y3 > y2 || y1 > y4);
}
int query(int p, int x1, int y1, int x2, int y2) {
if(!p) return 0;
if(checkin(t[p].mn[0], t[p].mn[1], t[p].mx[0], t[p].mx[1], x1, y1, x2, y2)) return t[p].siz;//被包含直接返回
if(checkout(t[p].mn[0], t[p].mn[1], t[p].mx[0], t[p].mx[1], x1, y1, x2, y2)) return 0;//没有交点直接返回
int res = 0;
if(chkin(t[p].v[0], t[p].v[1], x1, y1, x2, y2)) {//当前点合法,加入答案
res++;
}
return res + query(lp, x1, y1, x2, y2) + query(rp, x1, y1, x2, y2);//递归求解
}
void ins(int p, node k) {
a[cnt = 1] = k;
for(int i = 0; i < 18; i++) {
if(!rt[p][i]) {
build(rt[p][i], 1, cnt, 0);
break;
}
else append(rt[p][i]);
}
}
int que(int p, int x1, int y1, int x2, int y2) {
int ans = 0;
for(int i = 0; i < 18; i++) {
ans += query(rt[p][i], x1, y1, x2, y2);
}
return ans;
}
}
int rt;
namespace Sgt {
struct Segment_Tree {
int l, r;
}t[Maxn * 30];
int tot = 0;
void mdf(int &p, int l, int r, int x, node k) {
if(!p) p = ++tot;
KDT::ins(p, k);
if(l == r) {
return ;
}
int mid = (l + r) >> 1;
if(x <= mid) mdf(lp, l, mid, x, k);
else mdf(rp, mid + 1, r, x, k);
}
int query(int p, int l, int r, int k, int x1, int y1, int x2, int y2) {
if(!p) return Inf;
if(l == r) {
int ret = KDT::que(p, x1, y1, x2, y2);
return k <= ret ? l : Inf;
}
int mid = (l + r) >> 1;
int res = KDT::que(rp, x1, y1, x2, y2);
if(res < k) return query(lp, l, mid, k - res, x1, y1, x2, y2);
else return query(rp, mid + 1, r, k, x1, y1, x2, y2);
}
}
int main() {
ios::sync_with_stdio(0);
cin.tie(0), cout.tie(0);
cin >> n >> q;
int lst = 0;
while(q--) {
int typ, x1, y1, x2, y2, k;
cin >> typ;
switch(typ) {
case 1: {
cin >> x1 >> y1 >> k;
x1 ^= lst, y1 ^= lst, k ^= lst;
Sgt::mdf(rt, 1, N, k, (node){x1, y1});
break;
}
case 2: {
cin >> x1 >> y1 >> x2 >> y2 >> k;
x1 ^= lst, y1 ^= lst, x2 ^= lst, y2 ^= lst, k ^= lst;
lst = Sgt::query(rt, 1, N, k, x1, y1, x2, y2);
if(lst == Inf) lst = 0, cout << "NAIVE!ORZzyz.\n";
else cout << lst << '\n';
break;
}
}
}
return 0;
}