KD-Tree学习笔记

思想：
每个节点对应平面上的一个矩形区域（即其子树中所有点的包围盒）。
沿某一维把该节点所含的点分成两半，分别作为左右子树。
这样就把平面上的点查询问题转化成了二叉树上的查询问题：查询时可以利用矩形区域剪枝，跳过不可能产生答案的子树。

维持树的平衡性：
建树时，选择方差较大的那一维作为分割维，并以该维上的中位数作为分割点（可用 nth_element 在平均线性时间内求出）。
若需要动态插入点，则当某个节点的子树过于不平衡（过重）时，遍历该子树中的所有节点，将其重新建树。

P1429 平面最近点对(加强版)

#include <bits/stdc++.h>
using namespace std;

const int N = 2e5 + 10;  // array capacity; points are stored in p[1..n]
int n;                   // number of input points
// Best (smallest) squared distance found so far.
// Start at +infinity rather than a finite "large" value: a finite sentinel such as
// 2e18 is smaller than some legitimate squared distances (e.g. the pair
// (-1e9,-1e9)-(1e9,1e9) gives 8e18), which would cap the answer and wrongly prune.
double ans = std::numeric_limits<double>::infinity();

// Per-node summary of a k-d tree subtree. Node ids coincide with positions in p[],
// and id 0 means "no child".
struct kd_tree {
    double l, r, u, d;  // bounding box of the subtree: min x, max x, max y, min y
    int lc, rc;         // left / right child ids (0 = none)
} t[N];

// A point in the plane.
struct point {
    double x, y;
    // Squared Euclidean distance to p (no sqrt: callers only compare distances).
    double dis2(const point &p) const {
        return (x - p.x) * (x - p.x) + (y - p.y) * (y - p.y);
    }
} p[N];



void pushup(int x) {
    t[x].l = t[x].r = p[x].x; t[x].u = t[x].d = p[x].y;
    auto extend = [](int a, int b) {
        t[a].l = min(t[a].l, t[b].l);
        t[a].r = max(t[a].r, t[b].r);
        t[a].u = max(t[a].u, t[b].u);
        t[a].d = min(t[a].d, t[b].d);
    };
    if (t[x].lc) extend(x, t[x].lc);
    if (t[x].rc) extend(x, t[x].rc);
}

int build(int l, int r) {
    if (l > r) return 0;
    if (l == r) {
        t[l].l = t[l].r = p[l].x; t[l].u = t[l].d = p[l].y;
        return l;
    }
    //1.选择方差最大的维度
    //2.选择中位数进行分割
    int mid = (l + r) >> 1;
    double vx = 0, vy = 0, sx = 0, sy = 0;
    for (int i = l; i <= r; i++) vx += p[i].x, vy += p[i].y;
    vx /= 1.0 * (r - l + 1); vy /= 1.0 * (r - l + 1);
    for (int i = l; i <= r; i++)
        sx += (p[i].x - vx) * (p[i].x - vx), sy += (p[i].y - vy) * (p[i].y - vy);
    if (sx >= sy) nth_element(p + l, p + mid, p + r + 1, [](point a, point b) {return a.x < b.x;} );
    else nth_element(p + l, p + mid, p + r + 1, [](point a, point b) {return a.y < b.y;} );
    t[mid].lc = build(l, mid - 1); t[mid].rc = build(mid + 1, r);
    pushup(mid);
    return mid;
}



void query(int l, int r, int x) {
    if (l > r) return;
    int mid = (l + r) >> 1;
    if (mid != x) ans = min(ans, p[mid].dis2(p[x]));
    if (l == r) return;
    //求点到矩形的最短距离平方
    auto f = [](point &p, kd_tree &q) {
        double l = q.l, r = q.r, u = q.u, d = q.d, x = p.x, y = p.y;
        double res = 0;
        if (l > x) res += (l - x) * (l - x);
        if (r < x) res += (x - r) * (x - r);
        if (d > y) res += (d - y) * (d - y);
        if (u < y) res += (y - u) * (y - u);
        return res;
    };
    double disl = f(p[x], t[t[mid].lc]), disr = f(p[x], t[t[mid].rc]);
    //启发式查询
    if (disl < disr) {
        if (disl < ans) query(l, mid - 1, x);
        if (disr < ans) query(mid + 1, r, x);
    } else {
        if (disr < ans) query(mid + 1, r, x);
        if (disl < ans) query(l, mid - 1, x);
    }
}

int main() {
    scanf("%d", &n);
    for (int i = 1; i <= n; i++) scanf("%lf%lf", &p[i].x, &p[i].y);
    build(1, n);
    for (int i = 1; i <= n; i++) query(1, n, i);
    printf("%.4lf\n", sqrt(ans));
    return 0;
}

P4475 巧克力王国

//
// Created by blackbird on 2023/3/16.
//
#include <bits/stdc++.h>
#define int long long
using namespace std;

const int N = 200000 + 10;  // array capacity; points are stored in p[1..n]
int n, m;                   // number of points / number of queries
int a, b, c;                // current query: sum val over points with a*x + b*y < c
double ans = 2e18;          // not referenced by the code of this program

// Per-node summary of a k-d tree subtree; node ids coincide with positions in p[].
struct kd_tree {
    int l, r, u, d;  // bounding box of the subtree: min x, max x, max y, min y
    int lc, rc;      // child node ids, 0 = none
    int sum;         // total val of all points in the subtree
} t[N];

// A weighted point.
struct point {
    int x, y;
    int val;
    // Squared Euclidean distance to o, computed in integer arithmetic, then converted.
    double dis2(const point &o) const {
        const int dx = x - o.x;
        const int dy = y - o.y;
        return dx * dx + dy * dy;
    }
} p[N];



void pushup(int x) {
    t[x].l = t[x].r = p[x].x;
    t[x].u = t[x].d = p[x].y;
    t[x].sum = p[x].val;
    auto extend = [](int a, int b) {
        t[a].l = min(t[a].l, t[b].l);
        t[a].r = max(t[a].r, t[b].r);
        t[a].u = max(t[a].u, t[b].u);
        t[a].d = min(t[a].d, t[b].d);
        t[a].sum += t[b].sum;
    };
    if (t[x].lc) extend(x, t[x].lc);
    if (t[x].rc) extend(x, t[x].rc);
}

int build(int l, int r) {
    if (l > r) return 0;
    if (l == r) {
        t[l].l = t[l].r = p[l].x;
        t[l].u = t[l].d = p[l].y;
        t[l].sum = p[l].val;
        return l;
    }
    //1.选择方差最大的维度
    //2.选择中位数进行分割
    int mid = (l + r) >> 1;
    double vx = 0, vy = 0, sx = 0, sy = 0;
    for (int i = l; i <= r; i++) vx += p[i].x, vy += p[i].y;
    vx /= 1.0 * (r - l + 1); vy /= 1.0 * (r - l + 1);
    for (int i = l; i <= r; i++)
        sx += (p[i].x - vx) * (p[i].x - vx), sy += (p[i].y - vy) * (p[i].y - vy);
    if (sx >= sy) nth_element(p + l, p + mid, p + r + 1, [](point a, point b) {return a.x < b.x;} );
    else nth_element(p + l, p + mid, p + r + 1, [](point a, point b) {return a.y < b.y;} );
    t[mid].lc = build(l, mid - 1); t[mid].rc = build(mid + 1, r);
    pushup(mid);
    return mid;
}



int query(int u) {
    auto check = [](int x, int y) { return a * x + b * y < c; };
    int tmp = check(t[u].l, t[u].d) + check(t[u].l, t[u].u) + check(t[u].r, t[u].d) + check(t[u].r, t[u].u);
    if (tmp == 0) return 0;
    if (tmp == 4) return t[u].sum;
    int res = 0;
    if (check(p[u].x, p[u].y)) res += p[u].val;
    if (t[u].lc) res += query(t[u].lc);
    if (t[u].rc) res += query(t[u].rc);
    return res;
}

signed main() {
    // This program does all of its I/O through iostreams (no C stdio), so stream
    // synchronisation with stdio and the cin->cout flush tie can be turned off
    // to make bulk reading and writing fast.
    std::ios::sync_with_stdio(false);
    std::cin.tie(nullptr);
    std::cin >> n >> m;
    for (int i = 1; i <= n; i++)
        std::cin >> p[i].x >> p[i].y >> p[i].val;
    int rt = build(1, n);
    for (int i = 1; i <= m; i++) {
        // One query: total val over points with a*x + b*y < c.
        std::cin >> a >> b >> c;
        std::cout << query(rt) << '\n';
    }
    return 0;
}

posted @ 2023-03-16 21:59  _vv123  阅读(21)  评论(0)  编辑  收藏  举报