K-D Tree 学习笔记

K-D Tree 学习笔记

K-D Tree 是一种可以较高效维护高维信息的数据结构,矩形查询的时间复杂度一般是 O(n11/n) 的(我不会证),OI 中一般用到的都是 2-D Tree,也就是二维的 KDT,时间复杂度最优 O(logn),最劣 O(n)

下面以 k=2 为例介绍 KDT,更大的 k 可以很方便的从 k=2 拓展过来。

KDT

建树

先搬一张 OI-Wiki 的图过来:

假设有 A(2,3),B(4,7),C(5,4),D(7,2),E(9,6),F(8,1) 总共 6 个点。KDT 的构建方式如下:

  • 选择一个维度。
  • 选择一个切割点,将这一维度上值小于切割点的分入左子树,其余分入右子树。
  • 递归处理新分出来的两个子树。

上面的例子构建出来的 KDT 可能长这个样子:

具体说下过程:

  • 选择 x 轴作为当前维度,取点 D 作为分割点,将原来的点集划分为两个部分:[A,B,C,D,E,F][A,B,C],D,[E,F]
  • 对于 [A,B,C],选择 y 轴作为当前维度,选择点 C 作为分割点,划分为两部分:[A],C,[B]
  • 对于 [E,F],选择 y 轴作为当前维度,选择点 E 作为当前分割点,划分为两部分:E,[F]

但是这样生成出来的 KDT 的树高可能会很不平衡,所以需要人为去确定选择维度和分割点。

对于分割点的选择很好说,直接选择当前维度排序下的中位数即可。对于维度,一种比较不错的选择方法是选择各个维度中点的方差最大的作为选择的维度。这样构建出来的 KDT 树高最多为 logn 级别。

现在问题在于如何快速选出中位数。如果使用 sort,那么时间复杂度是 O(nlog2n) 的。algorithm 库中有一个函数 nth_element(begin,mid,end,cmp),作用是将 [begin, end] 中按照 cmp 规则小于 mid 放到左侧,大于的放到右侧,用只递归一半的快排实现,期望时间复杂度是 O(n) 的。

构建 KDT 的时间复杂度是 O(nlogn) 的。

插入 / 删除

如果每次插入操作都是对准一棵子树进行插入操作,那么树高将会退化成为 O(n) 的,并且由于 KDT 特殊的构造方式,使得它不能像 Treap 和 Splay 一样通过旋转保证树高。思考一下平衡树的各种实现方式,发现替罪羊树维护平衡的方式比较适合 KDT,即定义一个常数 α,如果一个节点的某个子树的大小超过了这个节点的 α,就将这个子树拍扁重构。重构方式就是上面提到的建树方式。

对于删除操作,也可以采用替罪羊树的操作方式,进行懒惰删除。如果一个子树未删除的节点占不到这棵子树的 α 就也将这个子树重构。

α 在替罪羊树上一般取 0.7,在 KDT 上取 0.6 更为合适,参考这篇博客

例题

领域查询

Luogu P7883 平面最近点对

此题就是一个用 KDT 骗分的很好的例子。

先按照题目要求建出 KDT,然后挨个枚举每一个点,查询距离最近的点,答案就是这些查询的最小值。

显然不能每一次都遍历 KDT,因为这样每次操作的时间复杂度都是满 O(n) 的,所以要加一些搜索剪枝的方法。

记录当前答案为 ans,如果当前节点代表的矩形距离距离当前点的距离已经大于了 ans,就没必要继续搜下去了。另一个剪枝就是如果两个子树的距离都是小于 ans 的,就选择距离更小的一个矩形进行搜索,这样搜索完了过后可能另一个子树的距离就大于 ans 了,就可以少去一个子树的搜索量。可以将当前点与矩形的距离理解为这种做法的估价函数。

随机数据下这种做法的时间是很优秀的,但是不难发现这种做法其实就是优化了搜索顺序的暴搜,最劣的时间复杂度仍然是 O(n) 的。

完整代码
#include<bits/stdc++.h>
using namespace std;
namespace Hanx16qwq {
constexpr int _SIZE = 4e5;
int n, ls[_SIZE + 5], rs[_SIZE + 5];
double ans = 2e18;
struct Node{
    double x, y;
}s[_SIZE + 5];
double L[_SIZE + 5], R[_SIZE + 5], U[_SIZE + 5], D[_SIZE + 5]; 
double dis(int a, int b) {
    return (s[a].x - s[b].x) * (s[a].x - s[b].x) + (s[a].y - s[b].y) * (s[a].y - s[b].y);
}
void Maintain(int x) {
    L[x] = R[x] = s[x].x;
    D[x] = U[x] = s[x].y;
    if (ls[x])
        L[x] = min(L[x], L[ls[x]]), R[x] = max(R[x], R[ls[x]]),
        D[x] = min(D[x], D[ls[x]]), U[x] = max(U[x], U[ls[x]]);
    if (rs[x])
        L[x] = min(L[x], L[rs[x]]), R[x] = max(R[x], R[rs[x]]),
        D[x] = min(D[x], D[rs[x]]), U[x] = max(U[x], U[rs[x]]);
}
int build(int l, int r) {
    if (l > r) return 0;
    if (l == r) {
        Maintain(l);
        return l;
    }
    int mid = (l + r) >> 1;
    double avx = 0, avy = 0, vax = 0, vay = 0;
    for (int i = l; i <= r; i++) avx += s[i].x, avy += s[i].y;
    avx /= r - l + 1, avy /= r - l + 1;
    for (int i = l; i <= r; i++)
        vax += (s[i].x - avx) * (s[i].x - avx),
        vay += (s[i].y - avy) * (s[i].y - avy);
    if (vax > vay)
        nth_element(s + l, s + mid, s + r + 1, [&](Node x, Node y) {
            return x.x < y.x;
        });
    else 
        nth_element(s + l, s + mid, s + r + 1, [&](Node x, Node y) {
            return x.y < y.y;
        });
    ls[mid] = build(l, mid - 1), rs[mid] = build(mid + 1, r);
    return Maintain(mid), mid;
}
double F(int a, int b) {
    double res = 0;
    if (L[b] > s[a].x) res += (L[b] - s[a].x) * (L[b] - s[a].x);
    if (R[b] < s[a].x) res += (R[b] - s[a].x) * (R[b] - s[a].x);
    if (D[b] > s[a].y) res += (D[b] - s[a].y) * (D[b] - s[a].y);
    if (U[b] < s[a].y) res += (U[b] - s[a].y) * (U[b] - s[a].y);
    return res;
}
void query(int l, int r, int x) {
    if (l > r) return;
    int mid = (l + r) >> 1;
    if (mid != x) ans = min(ans, dis(x, mid));
    if (l == r) return;
    double dist1 = F(x, ls[mid]), dist2 = F(x, rs[mid]);
    if (dist1 < ans && dist2 < ans) {
        if (dist1 < dist2) {
            query(l, mid - 1, x);
            if (dist2 < ans) query(mid + 1, r, x);
        } else {
            query(mid + 1, r, x);
            if (dist1 < ans) query(l, mid - 1, x);
        }
    } else {
        if (dist1 < ans) query(l, mid - 1, x);
        if (dist2 < ans) query(mid + 1, r, x);
    }
}
void main() {
    ios::sync_with_stdio(0);
    cin.tie(0); cout.tie(0);
    cin >> n;
    for (int i = 1; i <= n; i++) cin >> s[i].x >> s[i].y;
    build(1, n);
    for (int i = 1; i <= n; i++) query(1, n, i);
    cout << fixed << setprecision(0) << ans << '\n';
}
}
signed main() {
#ifdef DEBUG
    freopen("../test.in", "r", stdin);
    freopen("../test.out", "w", stdout);
#endif
    Hanx16qwq::main();
    return 0;
}

Luogu P4357 [CQOI2016]K 远点对

与上面一道题类似,此时的估价函数应该改成距离矩形最远点的距离,然后用堆维护前 k 大的距离即可。

具体做法就是维护一个小根堆,先往堆内加入 k0,然后对于每个子树如果当前子树估价小于了堆顶元素,那么就证明这个子树不会更新堆内元素了,所以就不需要搜索这个子树。当进入一个新节点的时候,如果这个节点大于堆顶元素,就弹出堆顶,加入当前节点。其它细节与上一题很类似。

完整代码
#include<bits/stdc++.h>
#define int long long

using namespace std;

namespace Hanx16qwq {
constexpr int _SIZE = 1e5;
int n, k;

struct Node {
	int x, y;
}s[_SIZE + 5];

int ls[_SIZE + 5], rs[_SIZE + 5];
int U[_SIZE + 5], D[_SIZE + 5], L[_SIZE + 5], R[_SIZE + 5];
priority_queue<int, vector<int>, greater<int>> q;

void Maintain(int x) {
	L[x] = R[x] = s[x].x;
	U[x] = D[x] = s[x].y;

	if (ls[x])
		L[x] = min(L[x], L[ls[x]]), R[x] = max(R[x], R[ls[x]]),
		D[x] = min(D[x], D[ls[x]]), U[x] = max(U[x], U[ls[x]]);

    if (rs[x])
        L[x] = min(L[x], L[rs[x]]), R[x] = max(R[x], R[rs[x]]),
        D[x] = min(D[x], D[rs[x]]), U[x] = max(U[x], U[rs[x]]);
}

int sq(int x) {return x * x;}

int Build(int l, int r) {
    if (l > r) return 0;

    int mid = (l + r) >> 1;
    double avx = 0, avy = 0, vax = 0, vay = 0;
    
    for (int i = l; i <= r; i++)
        avx += s[i].x, avy += s[i].y;

    avx /= (r - l + 1);
    avy /= (r - l + 1);

    for (int i = l; i <= r; i++)
        vax += sq(s[i].x - avx), vay += sq(s[i].y - avy);
    
    if (vax > vay)
        nth_element(s + l, s + mid, s + r + 1, [](Node x, Node y) {
            return x.x < y.x;
        });
    else
        nth_element(s + l, s + mid, s + r + 1, [](Node x, Node y) {
            return x.y < y.y;
        });

    ls[mid] = Build(l, mid - 1), rs[mid] = Build(mid + 1, r);
    return Maintain(mid), mid;
}

int calc(int a, int b) {
    return max(sq(s[a].x - L[b]), sq(s[a].x - R[b])) + 
           max(sq(s[a].y - U[b]), sq(s[a].y - D[b]));
}

void query(int l, int r, int x) {
    if (l > r) return;

    int mid = (l + r) >> 1;
    int res = sq(s[x].x  - s[mid].x) + sq(s[x].y - s[mid].y);

    if (res > q.top()) q.pop(), q.emplace(res);

    int dist1 = calc(x, ls[mid]), dist2 = calc(x, rs[mid]);

    if (dist1 > q.top() && dist2 > q.top()) {
        if (dist1 > dist2) {
            query(l, mid - 1, x);

            if (dist2 > q.top()) query(mid + 1, r, x);
        } else {
            query(mid + 1, r, x);

            if (dist1 > q.top()) query(l, mid - 1, x);
        }
    } else {
        if (dist1 > q.top()) query(l, mid - 1, x);

        if (dist2 > q.top()) query(mid + 1, r, x);
    }
}

void main() {
    ios::sync_with_stdio(0);
    cin.tie(0); cout.tie(0);
    cin >> n >> k;
    k <<= 1;

    for (int i = 1; i <= k; i++) q.emplace(0);

    for (int i = 1; i <= n; i++) cin >> s[i].x >> s[i].y;

    Build(1, n);

    for (int i = 1; i <= n; i++) query(1, n, i);

    cout << q.top() << '\n';
}
}

signed main() {
#ifdef DEBUG
    freopen("../test.in", "r", stdin);
    freopen("../test.out", "w", stdout);
#endif
    Hanx16qwq::main();
    return 0;
}

矩形查询

Luogu P4148 简单题

此题不仅卡空间,而且强制在线,就是摆明了让你写 KDT。

加入节点的方式上面有说到,不再说了。这里来说说怎么查询。

很明显,KDT 上每个节点都代表了一个矩形,如果这个矩形被查询区间完全覆盖,那么就将这个节点维护好的子树和直接贡献进入答案。如果完全无交,就没有继续搜索的必要了。否则就递归进入子树,每进入一个新节点就判断是否在查询区间内,是就贡献进入答案(有点像线段树)。

具体可以结合代码(封装的非常严实)。

完整代码
#include<bits/stdc++.h>

using namespace std;

namespace Hanx16qwq {
class KDT {
private:
    static const int _SIZE = 5e5;
    int ls[_SIZE + 5], rs[_SIZE + 5], d[_SIZE + 5];
    int L[_SIZE + 5], R[_SIZE + 5], D[_SIZE + 5], U[_SIZE + 5];
    int siz[_SIZE + 5], sum[_SIZE + 5];
    int ldt[_SIZE + 5];
    int root, cnt;
    struct Node{
        int x, y, v;
    }s[_SIZE + 5];
    const double alpha = 0.6;

    void Maintain(int x) {
        siz[x] = siz[ls[x]] + siz[rs[x]] + 1;
        sum[x] = sum[ls[x]] + sum[rs[x]] + s[x].v;
        L[x] = R[x] = s[x].x;
        D[x] = U[x] = s[x].y;

        if (ls[x])
            L[x] = min(L[x], L[ls[x]]), R[x] = max(R[x], R[ls[x]]),
            D[x] = min(D[x], D[ls[x]]), U[x] = max(U[x], U[ls[x]]);

        if (rs[x])
            L[x] = min(L[x], L[rs[x]]), R[x] = max(R[x], R[rs[x]]),
            D[x] = min(D[x], D[rs[x]]), U[x] = max(U[x], U[rs[x]]);
    }

    bool CanRbd(int x) {
        return siz[x] * alpha <= (double)max(siz[ls[x]], siz[rs[x]]);
    }

    void Flatten(int x, int &ldc) {
        if (ls[x]) Flatten(ls[x], ldc);

        ldt[++ldc] = x;

        if (rs[x]) Flatten(rs[x], ldc);
    }

    int sq(int x) {return x * x;}

    int Build(int l, int r) {
        if (l > r) return 0;

        int mid = (l + r) >> 1;
        double avx = 0, avy = 0, vax = 0, vay = 0;
        
        for (int i = l; i <= r; i++) avx += s[ldt[i]].x, avy += s[ldt[i]].y;

        avx /= r - l + 1, avy /= r - l + 1;

        for (int i = l; i <= r; i++)
            vax += sq(s[ldt[i]].x - avx), vay += sq(s[ldt[i]].y - avy);

        if (vax > vay)
            d[ldt[mid]] = 1, nth_element(ldt + l, ldt + mid, ldt + r + 1, [&](int x, int y) {
                return s[x].x < s[y].x;
            });
        else
            d[ldt[mid]] = 2, nth_element(ldt + l, ldt + mid, ldt + r + 1, [&](int x, int y) {
                return s[x].y < s[y].y;
            });

        ls[ldt[mid]] = Build(l, mid - 1), rs[ldt[mid]] = Build(mid + 1, r);
        return Maintain(ldt[mid]), ldt[mid];
    }

    void Rebuild(int &x) {
        int ldc = 0;
        Flatten(x, ldc);
        x = Build(1, ldc);
    }
    
    int NewNode(int x, int y, int v) {
        s[++cnt] = {x, y, v};
        return cnt;
    }

    void Insert(int &x, int a, int b, int w) {
        if (!x) {
            x = NewNode(a, b, w);
            return Maintain(x);
        }

        if (d[x] == 1) {
            if (a <= s[x].x) Insert(ls[x], a, b, w);
            else Insert(rs[x], a, b, w);
        } else {
            if (b <= s[x].y) Insert(ls[x], a, b, w);
            else Insert(rs[x], a, b, w);
        }

        Maintain(x);

        if (CanRbd(x)) Rebuild(x);
    }

    int Query(int x, int al, int ar, int au, int ad) {
        if (!x || R[x] < al || L[x] > ar || U[x] < ad || D[x] > au) return 0;

        if (L[x] >= al && R[x] <= ar && U[x] <= au && D[x] >= ad) return sum[x];

        int res = 0;

        if (s[x].x >= al && s[x].x <= ar && s[x].y <= au && s[x].y >= ad) res = s[x].v;

        return res + Query(ls[x], al, ar, au, ad) + Query(rs[x], al, ar, au, ad);
    }
public:
    KDT() {root = 0, cnt = 0;}

    void Insert(int a, int b, int w) {Insert(root, a, b, w);}

    int Query(int al, int ar, int au, int ad) {return Query(root, al, ar, au, ad);}
};

KDT t;
int n, last;

void main() {
    ios::sync_with_stdio(0);
    cin.tie(0); cout.tie(0);
    cin >> n;
    int opt, x, y, a, b;

    while (cin >> opt, opt != 3) {
        cin >> x >> y >> a;
        x ^= last, y ^= last, a ^= last;
        
        if (opt == 1) t.Insert(x, y, a);
        else {
            cin >> b; b ^= last;
            cout << (last = t.Query(x, a, b, y)) << '\n';
        }
    }
}
}
signed main() {
#ifdef DEBUG
    freopen("../test.in", "r", stdin);
    freopen("../test.out", "w", stdout);
#endif
    Hanx16qwq::main();
    return 0;
}
posted @   Hanx16Msgr  阅读(89)  评论(0编辑  收藏  举报
相关博文:
阅读排行:
· 全程不用写代码,我用AI程序员写了一个飞机大战
· DeepSeek 开源周回顾「GitHub 热点速览」
· 记一次.NET内存居高不下排查解决与启示
· MongoDB 8.0这个新功能碉堡了,比商业数据库还牛
· 白话解读 Dapr 1.15:你的「微服务管家」又秀新绝活了
点击右上角即可分享
微信分享提示