【算法】KD-tree
1. 算法简介
KD-tree(K-Dimensional),是一种对k维空间中的实例点进行存储以便对其进行快速检索的树形数据结构。 主要应用于多维空间关键数据的搜索。
KD-tree 的本质是一棵平衡树,将空间内的区域划分为一个超长方体,然后存储为节点进行维护。
以下为一个 \(k=2\) 时的 KD-tree。
2. 算法理论
考虑如何维护一个二维的平面信息。
将其分为多个区域,每次递归统计该区域内的子区域信息,进行维护。但是这样的做法过于劣。
那每次取中位点进行划分呢:考虑每次取横坐标的中位数所在的点作为节点,再依次递归到左右区间。这样可以使维护的代价更有优一些。
那每次两个维度交换划分呢:考虑每次按照横坐标、纵坐标、横坐标 \(\dots\) 这样的顺序进行划分,这样每次划分的区间会越来越小,维护的代价自然比划分一维的代价更优。
所以我们便知晓了 KD-tree 的基本原理:利用 \(k\) 维的交替划分,使得每一个子矩形的信息能用一个或多个点进行维护。这样就能再保证时间复杂度等各项性能最优的情况下,维护一个二维平面的信息了。
3. 算法实现
3.1 建树
根据 2. 算法理论 所述,我们知晓了 KD-tree 的基本建树原理。
首先,我们可以定义一个结构体来存储 KD-tree 上的节点。
struct KD_tree {
int x[2], v;
int l, r, sum;
int L[2], R[2];
} t[N];
这里 \(x\) 数组存储节点坐标,\(l,r\) 表示该节点在树上的左右儿子编号,\(v, sum\) 分别表示该节点的值以及该节点所管辖的子矩形的信息。\(L,R\) 数组分别表示该节点所管辖的子矩形的左下角、右上角的坐标。
将需要建成 KD-tree 的节点编号存入 \(b\) 数组中,并按以下顺序进行建树:
- 找出当前序列 \([l,r]\) 中 \(x/y\) 坐标的中位数所对应的节点,并将其存入当前节点。
- 取区间中点 \(mid\),分治处理 \([l,mid-1]\) 与 \([mid+1,r]\)。
注意,这里需要注意区间开闭的问题,在 KD-tree 中,子区间通常分为\([l,mid-1]\) 与 \([mid+1,r]\)。所以树上的一个节点对应的便是二维平面上的某一点。
如何快速取到中位数呢?这里可以使用C++库种自带的 nth_element
函数,它能找出序列 \(a\) 中第 \(k\) 大的数,并放置在位置 \(x\) 上,同时让小于 \(x\) 的所有位置上的数都小于等于 \(a_x\),大于 \(x\) 的所有位置上的数都大于等于 \(a_x\)。
然后运用运用线段数的建树思想,按照以上步骤构建一棵 KD-tree 即可。
Code:
int build(int l, int r, int D = 0) {//D=0/1 表示当前维度为 x/y
int mid = l + r >> 1;
nth_element(b + l, b + mid, b + r + 1, [D](int x, int y){return t[x].x[D] < t[y].x[D];});//找中位数并按要求组合顺序
int x = b[mid];//存入该节点
if(l < mid) t[x].l = build(l, mid - 1, D ^ 1);//递归左儿子 [l,mid-1]
if(r > mid) t[x].r = build(mid + 1, r, D ^ 1);//递归左儿子 [mid+1,r]
pushup(x);//信息上传
return x;
}
时间复杂度 \(O(n\log n)\)。
注意: nth_element
适用范围必须在当前区间 \([l,r]\) 内。
以下以 luoguP4148 简单题 为例。
3.2 矩形操作
3.2.1 信息上传(pushup)
考虑一个节点在二维平面上所管辖的矩形。
picture
可以发现,当前点所管瞎的矩形:
- 左下角坐标为该点子树内所有点的横纵坐标最小值所组成。
- 右上角坐标为该点子树内所有点的横纵坐标最大值所组成。
Code:
void pushup(int p) {
t[p].sum = t[t[p].l].sum + t[t[p].r].sum + t[p].v;//预处理加和
for (int k : {0, 1}) {
t[p].R[k] = t[p].L[k] = t[p].x[k];
if(t[p].l) {
t[p].L[k] = min(t[p].L[k], t[t[p].l].L[k]);//更新左下角坐标
t[p].R[k] = max(t[p].R[k], t[t[p].l].R[k]);//更新右下角坐标
}
if(t[p].r) {
t[p].L[k] = min(t[p].L[k], t[t[p].r].L[k]);//更新左下角坐标
t[p].R[k] = max(t[p].R[k], t[t[p].r].R[k]);//更新右下角坐标
}
}
}
当然,除了预处理 \(L,R\),同时需要预处理 \(sum\) 等题目所需的数据。
3.2.2 矩形查询
这里要分为三种情况讨论:矩形包含,矩形有交,矩形无交。
令查询矩形左下角坐标 \((x_1,y_1)\),右上角坐标 \((x_2,y_2)\)。当前矩形左下角坐标 \((t[p].L[0],t[p].L[1])\),右上角坐标 \((t[p].R[0],t[p].R[1])\)。则有:
对于矩形包含: 判断很好办,如果 \(x_0 \le t[p].L[0]\) 且 \(t[p].R[0] \le x_1\),同时 \(y_0 \le t[p].L[1]\) 且 \(t[p].R[1] \le y_1\),则当前矩形包含查询矩阵。
直接加上当前节点维护的信息即可,退回。
对于矩形无交: 判断也很好办,如果 \(x_0 > t[p].R[0]\) 或 \(t[p].L[0] > x_1\),或者 \(y_0 > t[p].R[1]\) 或 \(t[p].L[1] > y_1\),则当前矩形与查询矩阵无交。
直接退回。
对于矩形有交: 若矩形无交,则当前矩形与查询矩阵有交。
递归到左、右子树继续查询。
当然,由于树上的一个节点对应的都是二维平面上的某一点。 所以还要判断当前节点是否在查询矩形内,即 \(x_0 \le t[p].x[0]\) 且 \(t[p].x[0] \le x_1\),同时 \(y_0 \le t[p].x[1]\) 且 \(t[p].x[1] \le y_1\) 是否成立。成立就要加上该节点的信息。
于是有了矩形查询。
Code:
int query(int p) {
if(!p) return 0;
bool f = 1; int ans = 0;
for (int k : {0, 1}) {//判断矩形包含
f &= (t[p].L[k] >= lf.x[k] && t[p].R[k] <= rh.x[k]);
}
if(f) return t[p].sum;
for (int k : {0, 1}) {//判断矩形无交
if(t[p].R[k] < lf.x[k] || t[p].L[k] > rh.x[k]) return 0;
}
f = 1;
for (int k : {0, 1}) {//判断单点覆盖
f &= (lf.x[k] <= t[p].x[k] && t[p].x[k] <= rh.x[k]);
}
if(f) ans = t[p].v;//若覆盖则加上单点信息
return ans + query(t[p].l) + query(t[p].r);//若区间有交则递归左右子树继续查询。
}
3.3 单点修改
以下以 luoguP4148 简单题 为例。
需要单点修改 \((x,y)\),则需考虑将新点插入原 KD-tree 中,但此时保证不了树的平衡。
可以考虑将树用替罪羊树的思想,若不平衡拍扁重构。但是时间复杂度过于不优秀。
这里讲一个二进制分组的做法:
考虑设 \(rt_i\) 表示大小为 \(2^i\) 的 KD-tree 的根节点编号。当前新加入一个节点时,遍历 \(rt\) 数组。将遍历途中的存过的 \(rt_i\) 取出,若遇到一个空的 \(rt_i\),将取出的节点建立一棵 KD-tree 后放入当前 \(rt_i\)。
这样下来,修改操作的时间复杂度均滩 \(O(q\log^2 n)\)。
Code:
void Add(int &p) {//取点重构
if(!p) return ;
b[++cnt] = p;
Add(t[p].l);//取出左儿子
Add(t[p].r);//取出右儿子
p = 0;//删数据
}
for (int siz = 0; ; siz++) {
if(!rt[siz]) {
rt[siz] = build(1, cnt);//建树
break;
} else {
Add(rt[siz]);//将所有 rt_siz 内的节点拉出
}
}
当然查询的是一片森林,所以要遍历所有的根,时间复杂度为 \(O(\sum_{i\geq 0}(\frac{n}{2^i})^{1-\frac{1}{k}}) = O(n^{1-\frac{1}{k}})\)
Code:
lst = 0;
For(i,0,LG) lst += query(rt[i]);
cout << lst << '\n';
4. 例题
4.1 luoguP4148 简单题
Problem
你有一个\(N \times N\)的棋盘,每个格子内有一个整数,初始时的时候全部为 \(0\),现在需要维护两种操作:
1 x y A
\(1\le x,y\le N\),\(A\) 是正整数。将格子x
,y
里的数字加上 \(A\)。2 x1 y1 x2 y2
\(1 \le x_1 \le x_2 \le N\),\(1 \le y_1\le y_2 \le N\)。输出 \(x_1, y_1, x_2, y_2\) 这个矩形内的数字和3
无 终止程序
Solve
板子题不过多讲解
Code
#include<bits/stdc++.h>
#define ll long long
#define For(i,l,r) for(int i=l;i<=r;++i)
#define FOR(i,r,l) for(int i=r;i>=l;--i)
using namespace std;
const int N = 2e5 + 10, LG = 21;
struct KD_tree {
int x[2], v;
int l, r, sum;
int L[2], R[2];
} t[N], lf, rh;
int n, rt[30], b[N], cnt;
void pushup(int p) {
t[p].sum = t[t[p].l].sum + t[t[p].r].sum + t[p].v;
for (int k : {0, 1}) {
t[p].R[k] = t[p].L[k] = t[p].x[k];
if(t[p].l) {
t[p].L[k] = min(t[p].L[k], t[t[p].l].L[k]);
t[p].R[k] = max(t[p].R[k], t[t[p].l].R[k]);
}
if(t[p].r) {
t[p].L[k] = min(t[p].L[k], t[t[p].r].L[k]);
t[p].R[k] = max(t[p].R[k], t[t[p].r].R[k]);
}
}
}
int build(int l, int r, int D = 0) {
int mid = l + r >> 1;
nth_element(b + l, b + mid, b + r + 1, [D](int x, int y){return t[x].x[D] < t[y].x[D];});
int x = b[mid];
if(l < mid) t[x].l = build(l, mid - 1, D ^ 1);
if(r > mid) t[x].r = build(mid + 1, r, D ^ 1);
pushup(x);
return x;
}
int query(int p) {
if(!p) return 0;
bool f = 1; int ans = 0;
for (int k : {0, 1}) {
f &= (t[p].L[k] >= lf.x[k] && t[p].R[k] <= rh.x[k]);
}
if(f) return t[p].sum;
for (int k : {0, 1}) {
if(t[p].R[k] < lf.x[k] || t[p].L[k] > rh.x[k]) return 0;
}
f = 1;
for (int k : {0, 1}) {
f &= (lf.x[k] <= t[p].x[k] && t[p].x[k] <= rh.x[k]);
}
if(f) ans = t[p].v;
return ans + query(t[p].l) + query(t[p].r);
}
void Add(int &p) {
if(!p) return ;
b[++cnt] = p;
Add(t[p].l);
Add(t[p].r);
p = 0;
}
signed main() {
ios::sync_with_stdio(0);
cin.tie(0), cout.tie(0);
cin >> n;
n = 0; int lst = 0;
while(1) {
int op; cin >> op;
if(op == 3) break;
else if(op == 1) {
int x, y, w; cin >> x >> y >> w;
x ^= lst, y ^= lst, w ^= lst;
t[++n] = {{x, y}, w};
b[cnt=1] = n;
for (int siz = 0; ; siz++) {
if(!rt[siz]) {
rt[siz] = build(1, cnt);
break;
} else {
Add(rt[siz]);
}
}
} else {
cin >> lf.x[0] >> lf.x[1] >> rh.x[0] >> rh.x[1];
lf.x[0] ^= lst;
lf.x[1] ^= lst;
rh.x[0] ^= lst;
rh.x[1] ^= lst;
lst = 0;
For(i,0,LG) lst += query(rt[i]);
cout << lst << '\n';
}
}
return 0;
}