K-D Tree 学习笔记
K-D Tree 学习笔记
K-D Tree 是一种可以较高效维护高维信息的数据结构,矩形查询的时间复杂度一般是
下面以
KDT
建树
先搬一张 OI-Wiki 的图过来:
假设有
- 选择一个维度。
- 选择一个切割点,将这一维度上值小于切割点的分入左子树,其余分入右子树。
- 递归处理新分出来的两个子树。
上面的例子构建出来的 KDT 可能长这个样子:
具体说下过程:
- 选择
轴作为当前维度,取点 作为分割点,将原来的点集划分为两个部分: 。 - 对于
,选择 轴作为当前维度,选择点 作为分割点,划分为两部分: 。 - 对于
,选择 轴作为当前维度,选择点 作为当前分割点,划分为两部分: 。
但是这样生成出来的 KDT 的树高可能会很不平衡,所以需要人为去确定选择维度和分割点。
对于分割点的选择很好说,直接选择当前维度排序下的中位数即可。对于维度,一种比较不错的选择方法是选择各个维度中点的方差最大的作为选择的维度。这样构建出来的 KDT 树高最多为
现在问题在于如何快速选出中位数。如果使用 sort
,那么时间复杂度是 algorithm
库中有一个函数 nth_element(begin,mid,end,cmp)
,作用是将 [begin, end] 中按照 cmp
规则小于 mid
放到左侧,大于的放到右侧,用只递归一半的快排实现,期望时间复杂度是
构建 KDT 的时间复杂度是
插入 / 删除
如果每次插入操作都是对准一棵子树进行插入操作,那么树高将会退化成为
对于删除操作,也可以采用替罪羊树的操作方式,进行懒惰删除。如果一个子树未删除的节点占不到这棵子树的
例题
领域查询
Luogu P7883 平面最近点对
此题就是一个用 KDT 骗分的很好的例子。
先按照题目要求建出 KDT,然后挨个枚举每一个点,查询距离最近的点,答案就是这些查询的最小值。
显然不能每一次都遍历 KDT,因为这样每次操作的时间复杂度都是满
记录当前答案为
随机数据下这种做法的时间是很优秀的,但是不难发现这种做法其实就是优化了搜索顺序的暴搜,最劣的时间复杂度仍然是
完整代码
#include<bits/stdc++.h>
using namespace std;
namespace Hanx16qwq {
constexpr int _SIZE = 4e5;
int n, ls[_SIZE + 5], rs[_SIZE + 5];
double ans = 2e18;
struct Node{
double x, y;
}s[_SIZE + 5];
double L[_SIZE + 5], R[_SIZE + 5], U[_SIZE + 5], D[_SIZE + 5];
double dis(int a, int b) {
return (s[a].x - s[b].x) * (s[a].x - s[b].x) + (s[a].y - s[b].y) * (s[a].y - s[b].y);
}
void Maintain(int x) {
L[x] = R[x] = s[x].x;
D[x] = U[x] = s[x].y;
if (ls[x])
L[x] = min(L[x], L[ls[x]]), R[x] = max(R[x], R[ls[x]]),
D[x] = min(D[x], D[ls[x]]), U[x] = max(U[x], U[ls[x]]);
if (rs[x])
L[x] = min(L[x], L[rs[x]]), R[x] = max(R[x], R[rs[x]]),
D[x] = min(D[x], D[rs[x]]), U[x] = max(U[x], U[rs[x]]);
}
int build(int l, int r) {
if (l > r) return 0;
if (l == r) {
Maintain(l);
return l;
}
int mid = (l + r) >> 1;
double avx = 0, avy = 0, vax = 0, vay = 0;
for (int i = l; i <= r; i++) avx += s[i].x, avy += s[i].y;
avx /= r - l + 1, avy /= r - l + 1;
for (int i = l; i <= r; i++)
vax += (s[i].x - avx) * (s[i].x - avx),
vay += (s[i].y - avy) * (s[i].y - avy);
if (vax > vay)
nth_element(s + l, s + mid, s + r + 1, [&](Node x, Node y) {
return x.x < y.x;
});
else
nth_element(s + l, s + mid, s + r + 1, [&](Node x, Node y) {
return x.y < y.y;
});
ls[mid] = build(l, mid - 1), rs[mid] = build(mid + 1, r);
return Maintain(mid), mid;
}
double F(int a, int b) {
double res = 0;
if (L[b] > s[a].x) res += (L[b] - s[a].x) * (L[b] - s[a].x);
if (R[b] < s[a].x) res += (R[b] - s[a].x) * (R[b] - s[a].x);
if (D[b] > s[a].y) res += (D[b] - s[a].y) * (D[b] - s[a].y);
if (U[b] < s[a].y) res += (U[b] - s[a].y) * (U[b] - s[a].y);
return res;
}
void query(int l, int r, int x) {
if (l > r) return;
int mid = (l + r) >> 1;
if (mid != x) ans = min(ans, dis(x, mid));
if (l == r) return;
double dist1 = F(x, ls[mid]), dist2 = F(x, rs[mid]);
if (dist1 < ans && dist2 < ans) {
if (dist1 < dist2) {
query(l, mid - 1, x);
if (dist2 < ans) query(mid + 1, r, x);
} else {
query(mid + 1, r, x);
if (dist1 < ans) query(l, mid - 1, x);
}
} else {
if (dist1 < ans) query(l, mid - 1, x);
if (dist2 < ans) query(mid + 1, r, x);
}
}
void main() {
ios::sync_with_stdio(0);
cin.tie(0); cout.tie(0);
cin >> n;
for (int i = 1; i <= n; i++) cin >> s[i].x >> s[i].y;
build(1, n);
for (int i = 1; i <= n; i++) query(1, n, i);
cout << fixed << setprecision(0) << ans << '\n';
}
}
signed main() {
#ifdef DEBUG
freopen("../test.in", "r", stdin);
freopen("../test.out", "w", stdout);
#endif
Hanx16qwq::main();
return 0;
}
Luogu P4357 [CQOI2016]K 远点对
与上面一道题类似,此时的估价函数应该改成距离矩形最远点的距离,然后用堆维护前
具体做法就是维护一个小根堆,先往堆内加入
完整代码
#include<bits/stdc++.h>
#define int long long
using namespace std;
namespace Hanx16qwq {
constexpr int _SIZE = 1e5;
int n, k;
struct Node {
int x, y;
}s[_SIZE + 5];
int ls[_SIZE + 5], rs[_SIZE + 5];
int U[_SIZE + 5], D[_SIZE + 5], L[_SIZE + 5], R[_SIZE + 5];
priority_queue<int, vector<int>, greater<int>> q;
void Maintain(int x) {
L[x] = R[x] = s[x].x;
U[x] = D[x] = s[x].y;
if (ls[x])
L[x] = min(L[x], L[ls[x]]), R[x] = max(R[x], R[ls[x]]),
D[x] = min(D[x], D[ls[x]]), U[x] = max(U[x], U[ls[x]]);
if (rs[x])
L[x] = min(L[x], L[rs[x]]), R[x] = max(R[x], R[rs[x]]),
D[x] = min(D[x], D[rs[x]]), U[x] = max(U[x], U[rs[x]]);
}
int sq(int x) {return x * x;}
int Build(int l, int r) {
if (l > r) return 0;
int mid = (l + r) >> 1;
double avx = 0, avy = 0, vax = 0, vay = 0;
for (int i = l; i <= r; i++)
avx += s[i].x, avy += s[i].y;
avx /= (r - l + 1);
avy /= (r - l + 1);
for (int i = l; i <= r; i++)
vax += sq(s[i].x - avx), vay += sq(s[i].y - avy);
if (vax > vay)
nth_element(s + l, s + mid, s + r + 1, [](Node x, Node y) {
return x.x < y.x;
});
else
nth_element(s + l, s + mid, s + r + 1, [](Node x, Node y) {
return x.y < y.y;
});
ls[mid] = Build(l, mid - 1), rs[mid] = Build(mid + 1, r);
return Maintain(mid), mid;
}
int calc(int a, int b) {
return max(sq(s[a].x - L[b]), sq(s[a].x - R[b])) +
max(sq(s[a].y - U[b]), sq(s[a].y - D[b]));
}
void query(int l, int r, int x) {
if (l > r) return;
int mid = (l + r) >> 1;
int res = sq(s[x].x - s[mid].x) + sq(s[x].y - s[mid].y);
if (res > q.top()) q.pop(), q.emplace(res);
int dist1 = calc(x, ls[mid]), dist2 = calc(x, rs[mid]);
if (dist1 > q.top() && dist2 > q.top()) {
if (dist1 > dist2) {
query(l, mid - 1, x);
if (dist2 > q.top()) query(mid + 1, r, x);
} else {
query(mid + 1, r, x);
if (dist1 > q.top()) query(l, mid - 1, x);
}
} else {
if (dist1 > q.top()) query(l, mid - 1, x);
if (dist2 > q.top()) query(mid + 1, r, x);
}
}
void main() {
ios::sync_with_stdio(0);
cin.tie(0); cout.tie(0);
cin >> n >> k;
k <<= 1;
for (int i = 1; i <= k; i++) q.emplace(0);
for (int i = 1; i <= n; i++) cin >> s[i].x >> s[i].y;
Build(1, n);
for (int i = 1; i <= n; i++) query(1, n, i);
cout << q.top() << '\n';
}
}
signed main() {
#ifdef DEBUG
freopen("../test.in", "r", stdin);
freopen("../test.out", "w", stdout);
#endif
Hanx16qwq::main();
return 0;
}
矩形查询
Luogu P4148 简单题
此题不仅卡空间,而且强制在线,就是摆明了让你写 KDT。
加入节点的方式上面有说到,不再说了。这里来说说怎么查询。
很明显,KDT 上每个节点都代表了一个矩形,如果这个矩形被查询区间完全覆盖,那么就将这个节点维护好的子树和直接贡献进入答案。如果完全无交,就没有继续搜索的必要了。否则就递归进入子树,每进入一个新节点就判断是否在查询区间内,是就贡献进入答案(有点像线段树)。
具体可以结合代码(封装的非常严实)。
完整代码
#include<bits/stdc++.h>
using namespace std;
namespace Hanx16qwq {
class KDT {
private:
static const int _SIZE = 5e5;
int ls[_SIZE + 5], rs[_SIZE + 5], d[_SIZE + 5];
int L[_SIZE + 5], R[_SIZE + 5], D[_SIZE + 5], U[_SIZE + 5];
int siz[_SIZE + 5], sum[_SIZE + 5];
int ldt[_SIZE + 5];
int root, cnt;
struct Node{
int x, y, v;
}s[_SIZE + 5];
const double alpha = 0.6;
void Maintain(int x) {
siz[x] = siz[ls[x]] + siz[rs[x]] + 1;
sum[x] = sum[ls[x]] + sum[rs[x]] + s[x].v;
L[x] = R[x] = s[x].x;
D[x] = U[x] = s[x].y;
if (ls[x])
L[x] = min(L[x], L[ls[x]]), R[x] = max(R[x], R[ls[x]]),
D[x] = min(D[x], D[ls[x]]), U[x] = max(U[x], U[ls[x]]);
if (rs[x])
L[x] = min(L[x], L[rs[x]]), R[x] = max(R[x], R[rs[x]]),
D[x] = min(D[x], D[rs[x]]), U[x] = max(U[x], U[rs[x]]);
}
bool CanRbd(int x) {
return siz[x] * alpha <= (double)max(siz[ls[x]], siz[rs[x]]);
}
void Flatten(int x, int &ldc) {
if (ls[x]) Flatten(ls[x], ldc);
ldt[++ldc] = x;
if (rs[x]) Flatten(rs[x], ldc);
}
int sq(int x) {return x * x;}
int Build(int l, int r) {
if (l > r) return 0;
int mid = (l + r) >> 1;
double avx = 0, avy = 0, vax = 0, vay = 0;
for (int i = l; i <= r; i++) avx += s[ldt[i]].x, avy += s[ldt[i]].y;
avx /= r - l + 1, avy /= r - l + 1;
for (int i = l; i <= r; i++)
vax += sq(s[ldt[i]].x - avx), vay += sq(s[ldt[i]].y - avy);
if (vax > vay)
d[ldt[mid]] = 1, nth_element(ldt + l, ldt + mid, ldt + r + 1, [&](int x, int y) {
return s[x].x < s[y].x;
});
else
d[ldt[mid]] = 2, nth_element(ldt + l, ldt + mid, ldt + r + 1, [&](int x, int y) {
return s[x].y < s[y].y;
});
ls[ldt[mid]] = Build(l, mid - 1), rs[ldt[mid]] = Build(mid + 1, r);
return Maintain(ldt[mid]), ldt[mid];
}
void Rebuild(int &x) {
int ldc = 0;
Flatten(x, ldc);
x = Build(1, ldc);
}
int NewNode(int x, int y, int v) {
s[++cnt] = {x, y, v};
return cnt;
}
void Insert(int &x, int a, int b, int w) {
if (!x) {
x = NewNode(a, b, w);
return Maintain(x);
}
if (d[x] == 1) {
if (a <= s[x].x) Insert(ls[x], a, b, w);
else Insert(rs[x], a, b, w);
} else {
if (b <= s[x].y) Insert(ls[x], a, b, w);
else Insert(rs[x], a, b, w);
}
Maintain(x);
if (CanRbd(x)) Rebuild(x);
}
int Query(int x, int al, int ar, int au, int ad) {
if (!x || R[x] < al || L[x] > ar || U[x] < ad || D[x] > au) return 0;
if (L[x] >= al && R[x] <= ar && U[x] <= au && D[x] >= ad) return sum[x];
int res = 0;
if (s[x].x >= al && s[x].x <= ar && s[x].y <= au && s[x].y >= ad) res = s[x].v;
return res + Query(ls[x], al, ar, au, ad) + Query(rs[x], al, ar, au, ad);
}
public:
KDT() {root = 0, cnt = 0;}
void Insert(int a, int b, int w) {Insert(root, a, b, w);}
int Query(int al, int ar, int au, int ad) {return Query(root, al, ar, au, ad);}
};
KDT t;
int n, last;
void main() {
ios::sync_with_stdio(0);
cin.tie(0); cout.tie(0);
cin >> n;
int opt, x, y, a, b;
while (cin >> opt, opt != 3) {
cin >> x >> y >> a;
x ^= last, y ^= last, a ^= last;
if (opt == 1) t.Insert(x, y, a);
else {
cin >> b; b ^= last;
cout << (last = t.Query(x, a, b, y)) << '\n';
}
}
}
}
signed main() {
#ifdef DEBUG
freopen("../test.in", "r", stdin);
freopen("../test.out", "w", stdout);
#endif
Hanx16qwq::main();
return 0;
}
【推荐】国内首个AI IDE,深度理解中文开发场景,立即下载体验Trae
【推荐】编程新体验,更懂你的AI,立即体验豆包MarsCode编程助手
【推荐】抖音旗下AI助手豆包,你的智能百科全书,全免费不限次数
【推荐】轻量又高性能的 SSH 工具 IShell:AI 加持,快人一步
· 全程不用写代码,我用AI程序员写了一个飞机大战
· DeepSeek 开源周回顾「GitHub 热点速览」
· 记一次.NET内存居高不下排查解决与启示
· MongoDB 8.0这个新功能碉堡了,比商业数据库还牛
· 白话解读 Dapr 1.15:你的「微服务管家」又秀新绝活了