KDT 从入门到夺门而出

简介

首先要知道 \(KD-Tree\) 是干什么的,它最广泛的用法便是维护 \(k\) 维最近点对(大部分时候是二维)。

先来讲没有插入,直接建树的。

它的每个结点维护这样子的数据,其中 \(lc\)\(rc\) 代表左右儿子,\(v[i]\) 代表第 \(i\) 维当前点的取值,\(L[i]\)\(U[i]\) 分别代表第 \(i\) 维上当前点对应的子树中所有点的范围,它其实对应了一个矩形

比如一棵子树中根为 \((114514,114514)\),左儿子为 \((1919, 1919)\) 右儿子为 \((810,810)\),那么对于根来说 \(L[1] = 810, R[1] = 114514\)

struct KD {
    int lc, rc;
    double v[2], L[2], U[2];
    bool operator < (const KD &t) const {
        return v[K] < t.v[K];
    }
} tr[N]; 

然后是建树过程,其实和替罪羊树的重构过程很像,都是拍扁再拎起来。具体地我们使用 nth_element 函数,将一排数根据 \(mid\) 分成两半。同时,为了保证复杂度,我们要轮流对第 \(k\) 维排序,即 \(0,1\) 循环,上文的比较函数也是为了这个所定义的。值得注意的是 \(k\) 是函数内的 \(K\) 是一个全局变量。

pushup 操作也比较简朴,看看就懂了,就是用儿子更新父亲。

void pushup(int u) {
    rep(i, 0, 1) {
        tr[u].L[i] = tr[u].U[i] = tr[u].v[i];
        if (tr[u].lc) {
            tr[u].L[i] = min(tr[u].L[i], tr[tr[u].lc].L[i]);
            tr[u].U[i] = max(tr[u].U[i], tr[tr[u].lc].U[i]);
        } 
        if (tr[u].rc) {
            tr[u].L[i] = min(tr[u].L[i], tr[tr[u].rc].L[i]);
            tr[u].U[i] = max(tr[u].U[i], tr[tr[u].rc].U[i]);
        }
    }
}
int build(int l, int r, int k) {
    if (l > r) return 0;
    int mid = l + r >> 1;
    K = k; nth_element(tr + l, tr + mid, tr + r);
    tr[mid].lc = build(l, mid - 1, k ^ 1);
    tr[mid].rc = build(mid + 1, r, k ^ 1);
    pushup(mid);
    return mid;
}

查询也不算非常困难,我们从根开始,分别计算要查的结点 \(cur\) 到根的距离,以及到左右儿子所在范围的最近距离。由于上文的 \(L\)\(U\) 在两维状态下可以看做是一个矩形,所以相当于一个点到矩形的最短距离。可以看看代码画图理解一下。

回到 query 中,我们得知 \(dist\) 后,可以贪心地在左右儿子中选择 \(dist\) 小的来优先更新,然后在考虑另一侧。同时左右边的最优答案一定得小于当前全局最优值,否则不用更新。

inline double sq(double x) {
    return x * x;
}
inline double dis1(int x) {
    double res = 0;
    rep(i, 0, 1) res += sq(tr[cur].v[i] - tr[x].v[i]);
    return res;
}
inline double dis2(int x) {
    if (!x) return 2e18;
    double res = 0;
    rep(i, 0, 1) res += sq(max(0.0, tr[cur].v[i] - tr[x].U[i])) + sq(max(0.0, tr[x].L[i] - tr[cur].v[i]));
    return res;
}
void query(int u) {
    if (!u) return;
    if (u != cur) ans = min(ans, dis1(u));
    double d1 = dis2(tr[u].lc), d2 = dis2(tr[u].rc);
    if (d1 < d2) {
        if (d1 < ans) query(tr[u].lc);
        if (d2 < ans) query(tr[u].rc);
    } else {
        if (d2 < ans) query(tr[u].rc);
        if (d1 < ans) query(tr[u].lc);
    }
}

经过证明(我不会),在处理二维时的复杂度是根号的,\(build\)\(O(nlogn)\)\(k\) 维是 \(O(n^{1 - \frac{1}{k}})\) 的。

当然你也可以动态差点不 \(build\),同样类似于替罪羊数当 \(A * sz[root] \geq max(sz[lc], sz[rc])\) 时就直接重构。\(A\) 我一般取 \(0.7\)

bool cmp(int a, int b) {
    return tr[a].v[K] < tr[b].v[K];
}
int rebuild(int l, int r, int k) {
    if (l > r) return 0;
    int mid = l + r >> 1;
    K = k; nth_element(g + l, g + mid, g + r + 1, cmp);
    tr[g[mid]].lc = rebuild(l, mid - 1, k ^ 1);
    tr[g[mid]].rc = rebuild(mid + 1, r, k ^ 1);
    pushup(g[mid]);
    return g[mid];
}
void dfs(int u) {
    if (!u) return;
    g[++ cnt] = u;
    dfs(tr[u].lc);
    dfs(tr[u].rc);
}
void check(int &u, int k) {
    if (tr[u].sz * A < max(tr[tr[u].lc].sz, tr[tr[u].rc].sz)) 
        cnt = 0, dfs(u), u = rebuild(1, cnt, k);
}
void insert(int &u, int k) {
    if (!u) { u = cur; pushup(u); return; }
    insert(tr[cur].v[k] <= tr[u].v[k] ? tr[u].lc : tr[u].rc, k ^ 1);
    pushup(u);
    check(u, k);
}

模板

P1429 平面最近点对(加强版)为例,贴个板子。

#include <bits/stdc++.h>
#define rep(i, a, b) for (int i = (a); i <= (b); i ++)
#define fro(i, a, b) for (int i = (a); i >= b; i --)
#define INF 0x3f3f3f3f
#define eps 1e-6
#define lowbit(x) (x & (-x))
#define initrand srand((unsigned)time(0))
#define random(x) ((LL)rand() * rand() % (x))
#define eb emplace_back
using namespace std;
typedef long long LL;
typedef pair<int, int> PII;
typedef pair<double, int> PDI;
inline int read() {
    int x = 0, f = 1;
    char ch = getchar();
    while (ch < '0' || ch > '9') { if (ch == '-') f = -1; ch = getchar(); }
    while (ch >= '0' && ch <= '9') { x = (x << 1) + (x << 3) + (ch ^ 48); ch = getchar(); }
    return x * f;
}

const int N = 200010;
int n, K, cur;
double ans = 2e18;

struct KD {
    int lc, rc;
    double v[2], L[2], U[2];
    bool operator < (const KD &t) const {
        return v[K] < t.v[K];
    }
} tr[N]; 

void pushup(int u) {
    rep(i, 0, 1) {
        tr[u].L[i] = tr[u].U[i] = tr[u].v[i];
        if (tr[u].lc) {
            tr[u].L[i] = min(tr[u].L[i], tr[tr[u].lc].L[i]);
            tr[u].U[i] = max(tr[u].U[i], tr[tr[u].lc].U[i]);
        } 
        if (tr[u].rc) {
            tr[u].L[i] = min(tr[u].L[i], tr[tr[u].rc].L[i]);
            tr[u].U[i] = max(tr[u].U[i], tr[tr[u].rc].U[i]);
        }
    }
}

int build(int l, int r, int k) {
    if (l > r) return 0;
    int mid = l + r >> 1;
    K = k; nth_element(tr + l, tr + mid, tr + r + 1);
    tr[mid].lc = build(l, mid - 1, k ^ 1);
    tr[mid].rc = build(mid + 1, r, k ^ 1);
    pushup(mid);
    return mid;
}

inline double sq(double x) {
    return x * x;
}

inline double dis1(int x) {
    double res = 0;
    rep(i, 0, 1) res += sq(tr[cur].v[i] - tr[x].v[i]);
    return res;
}

inline double dis2(int x) {
    if (!x) return 2e18;
    double res = 0;
    rep(i, 0, 1) res += sq(max(0.0, tr[cur].v[i] - tr[x].U[i])) + sq(max(0.0, tr[x].L[i] - tr[cur].v[i]));
    return res;
}

void query(int u) {
    if (!u) return;
    if (u != cur) ans = min(ans, dis1(u));
    double d1 = dis2(tr[u].lc), d2 = dis2(tr[u].rc);
    if (d1 < d2) {
        if (d1 < ans) query(tr[u].lc);
        if (d2 < ans) query(tr[u].rc);
    } else {
        if (d2 < ans) query(tr[u].rc);
        if (d1 < ans) query(tr[u].lc);
    }
}

int main() {
    n = read();
    rep(i, 1, n) scanf("%lf%lf", &tr[i].v[0], &tr[i].v[1]);
    int root = build(1, n, 0);
    for (cur = 1; cur <= n; cur ++) query(root); 
    printf("%.4lf\n", sqrt(ans));
    return 0;
}
#include <bits/stdc++.h>
#define rep(i, a, b) for (int i = (a); i <= (b); i ++)
#define fro(i, a, b) for (int i = (a); i >= b; i --)
#define INF 0x3f3f3f3f
#define eps 1e-6
#define lowbit(x) (x & (-x))
#define initrand srand((unsigned)time(0))
#define random(x) ((LL)rand() * rand() % (x))
#define eb emplace_back
using namespace std;
typedef long long LL;
typedef pair<int, int> PII;
typedef pair<double, int> PDI;
inline int read() {
    int x = 0, f = 1;
    char ch = getchar();
    while (ch < '0' || ch > '9') { if (ch == '-') f = -1; ch = getchar(); }
    while (ch >= '0' && ch <= '9') { x = (x << 1) + (x << 3) + (ch ^ 48); ch = getchar(); }
    return x * f;
}

const int N = 200010;
const double A = 0.7;
int n, K, root, cur;
double ans = 2e18;

int idx;
struct KD {
    int lc, rc, sz;
    double v[2], L[2], U[2];
    bool operator < (const KD &t) const {
        return v[K] < t.v[K];
    }
} tr[N]; 

void pushup(int u) {
    tr[u].sz = tr[tr[u].lc].sz + tr[tr[u].rc].sz;
    rep(i, 0, 1) {
        tr[u].L[i] = tr[u].U[i] = tr[u].v[i];
        if (tr[u].lc) {
            tr[u].L[i] = min(tr[u].L[i], tr[tr[u].lc].L[i]);
            tr[u].U[i] = max(tr[u].U[i], tr[tr[u].lc].U[i]);
        } 
        if (tr[u].rc) {
            tr[u].L[i] = min(tr[u].L[i], tr[tr[u].rc].L[i]);
            tr[u].U[i] = max(tr[u].U[i], tr[tr[u].rc].U[i]);
        }
    }
}

int g[N], cnt;

bool cmp(int a, int b) {
    return tr[a].v[K] < tr[b].v[K];
}

int rebuild(int l, int r, int k) {
    if (l > r) return 0;
    int mid = l + r >> 1;
    K = k; nth_element(g + l, g + mid, g + r + 1, cmp);
    tr[g[mid]].lc = rebuild(l, mid - 1, k ^ 1);
    tr[g[mid]].rc = rebuild(mid + 1, r, k ^ 1);
    pushup(g[mid]);
    return g[mid];
}

void dfs(int u) {
    if (!u) return;
    g[++ cnt] = u;
    dfs(tr[u].lc);
    dfs(tr[u].rc);
}

void check(int &u, int k) {
    if (tr[u].sz * A < max(tr[tr[u].lc].sz, tr[tr[u].rc].sz)) 
        cnt = 0, dfs(u), u = rebuild(1, cnt, k);
}

void insert(int &u, int k) {
    if (!u) { u = cur; pushup(u); return; }
    insert(tr[cur].v[k] <= tr[u].v[k] ? tr[u].lc : tr[u].rc, k ^ 1);
    pushup(u);
    check(u, k);
}

inline double sq(double x) {
    return x * x;
}

inline double dis1(int x) {
    double res = 0;
    rep(i, 0, 1) res += sq(tr[cur].v[i] - tr[x].v[i]);
    return res;
}

inline double dis2(int x) {
    if (!x) return 2e18;
    double res = 0;
    rep(i, 0, 1) res += sq(max(0.0, tr[cur].v[i] - tr[x].U[i])) + sq(max(0.0, tr[x].L[i] - tr[cur].v[i]));
    return res;
}

void query(int u) {
    if (!u) return;
    if (u != cur) ans = min(ans, dis1(u));
    double d1 = dis2(tr[u].lc), d2 = dis2(tr[u].rc);
    if (d1 < d2) {
        if (d1 < ans) query(tr[u].lc);
        if (d2 < ans) query(tr[u].rc);
    } else {
        if (d2 < ans) query(tr[u].rc);
        if (d1 < ans) query(tr[u].lc);
    }
}

int main() {
    n = read();
    rep(i, 1, n) scanf("%lf%lf", &tr[i].v[0], &tr[i].v[1]);
    for (cur = 1; cur <= n; cur ++) insert(root, 0);
    for (cur = 1; cur <= n; cur ++) query(root); 
    printf("%.4lf\n", sqrt(ans));
    return 0;
}

模板题

P2479 [SDOI2010] 捉迷藏

距离计算更改为曼哈顿距离,然后再增加统计一下最大值即可。

直接建树会比动态插入快非常多。

代码

P4148 简单题

操作 \(1\) 可以看作动态插点,操作 \(2\) 可以直接用类似线段树查询的方式,对于每个结点分类讨论三种(我们将一个 \(KDT\) 结点表示范围看作矩形):

  1. 该节点所对矩形完全不包含于询问矩形
  2. 该节点所对矩形完全包含于询问矩形
  3. 部分包含

对于 \(1\)\(2\) 来说是简单的,对于 \(3\),我们可以判断一下当前节点的 \(v\) 是否在矩形内,然后递归左右子树最后加上根的贡献。
时间复杂度被证明是根号的。

代码

困难一点的题

P5471 [NOI2019] 弹跳

考虑 \(KDT\) 优化建图后跑 \(dijkstra\),我们把 \(1\sim n\) 记为原始的点(下文称为实点),\(n + 1\sim 2n\) 记为 \(KDT\) 建出来的点(虚点)。显然一个实点 \(u\) 对应虚点 \(u + n\)。对于一个虚点 \(u\),它显然可以向 \(u - n\) 连边,也可以向它在 \(KDT\) 中的左右儿子连边;对于一个实点,我们遍历从该点出发的弹跳装置,分类讨论(以下点均为实点所对虚点):

  1. 一个虚点 \(x\) 对应的矩阵完全包含在弹跳装置内,直接将 \(u\)\(x\) 连边
  2. 完全不包含,直接返回
  3. 部分包含的话,如果该虚点的坐标能够包含于弹跳装置内就让 \(u\)\(x - n\) 连边

然而如果真的连边的话会被卡爆,我们考虑边做 \(dij\) 边跑上面过程,每次不建边直接用需要的点进行更新操作。

于是这道题就做完了,代码稍微有点难写。

代码

posted @   比翼の鼠  阅读(11)  评论(0编辑  收藏  举报
//雪花飘落效果
评论
收藏
关注
推荐
深色
回顶
收起
点击右上角即可分享
微信分享提示