点分治 学习笔记

引入

在点分治的过程中,它的遍历顺序会遍历每棵子树的重心,而这棵由重心生成的树会产生一棵新的树,便是点分树。

常用来解决树上与树的形态无关的路径问题。

过程

如下图,它的点分树是它自己。


因此可以看出一棵树的点分树可能是它本身。

性质

  1. 因为点分树 \(\mathcal{O}(\log n)\) 因此很显然的是点分树树高也是 \(\mathcal{O}(\log n)\) 的。
  2. 树上两点的 \(LCA\) 必定在原树两点路径上。

根据第一条性质,点分树就可以解决一些复杂度瓶颈在树的形态上的且问题与树的形态无关的题。

例题

【模板】点分治 | 震波

简要题意

给定 \(n(1 \le n \le 10 ^ 5)\) 个点的树,每个点 \(i\) 点权为 \(w_i (0 \le w_i \le 10^4)\)\(m(1 \le m \le 10^5)\) 次操作。

每次操作有修改或查询。

修改操作,将第 \(i\) 个点的点权改为 \(v(1 \le v \le 10^4)\)
查询操作,查询距离 \(i\) 距离小于等于 \(k(0 \le k \le n - 1)\) 的点的点权和。

思路

点分治模板题,建立出点分树后,只需要处理出每个点的子树中距离它为 \(k\) 的点的点权和和它的父亲距离它的子树中距离为 \(k\) 的点权和。

之后修改就暴力跳父亲,用树状数组修改;查询也暴力跳父亲查询即可。

代码
#include <cstdio>
#include <vector>
#include <bitset>
#include <queue>
#include <algorithm>
#include <iostream>

using u32 = unsigned int ;
using i64 = long long ;
using u64 = unsigned long long ;

const int N = 1e5 + 5 ;

struct FenwickTree{
    std::vector<int> t;
    u32 siz;

    void resize(const int & n){
        siz = n + 5;
        t.resize(n + 10, 0);
    }

    int lowbit(const int & x){
        return x & -x;
    }
    
    void modify(u32 x, int val){
        ++x;
        for (; x <= siz; x += lowbit(x)) {
            t[x] += val;
        }
    }
    
    int query(u32 x){
        int ans = 0;
        
        ++x;
        x = std::min(x, siz);
        for (; x; x -= lowbit(x)) {
            ans += t[x];
        }

        return ans;
    }
}t[2][N];
// t[0][i] 第 i 个点的子树距离它小于等于 k 的权值和, t[1][i] 第 i 个点的父亲距离它子树小于等于 k 的权值和

int n, m;
int val[N];

std::vector<int> g[N];
int dfa[N];

bool vis[N];

int siz[N], rt;
void GetRoot(int u, int FA, int sum){
    siz[u] = 1;
    int maxn = 0;

    for (auto v : g[u]) {
        if (v != FA && !vis[v]) {
            GetRoot(v, u, sum);
            siz[u] += siz[v];
            maxn = std::max(maxn, siz[v]);
        }
    }

    maxn = std::max(maxn, sum - siz[u]);
    if (maxn * 2 <= sum) {
        rt = u;
    }
}

// ----------------------------- 树链剖分
int fa[N], dep[N];
int son[N];
void dfs1(int u, int FA){
    siz[u] = 1;
    fa[u] = FA;
    dep[u] = dep[FA] + 1;

    for (auto v : g[u]) {
        if (v != FA) {
            dfs1(v, u);
            siz[u] += siz[v];
            if (siz[son[u]] < siz[v]) {
                son[u] = v;
            }
        }
    }
}

int top[N];
void dfs2(int u, int x){
    top[u] = x;

    if (!son[u]) {
        return;
    }

    dfs2(son[u], x);

    for (auto v : g[u]) {
        if (v != son[u] && v != fa[u]) {
            dfs2(v, v);
        }
    }
}
// ---------------------------- end

void dfs(int u){
    vis[u] = true;

    t[0][u].resize(siz[u]);

    for (auto v : g[u]) {
        t[1][v].resize(siz[v]);
    }

    for (auto v : g[u]) {
        if (!vis[v]) {
            GetRoot(v, u, siz[v]);
            GetRoot(rt, 0, siz[v]);

            dfa[rt] = u;
            dfs(rt);
        }
    }
}

int LCA(int x, int y){
    while (top[x] != top[y]) {
        if (dep[top[x]] < dep[top[y]]) {
            std::swap(x, y);
        }
        x = fa[top[x]];
    }
    return dep[x] < dep[y]? x: y;
}

int dist(int x, int y){
    return dep[x] + dep[y] - (dep[LCA(x, y)] << 1);
}

int query(int x, int k){
    int res = t[0][x].query(k);
    int cur = x;

    while (dfa[cur]) {
        int d = dist(dfa[cur], x);
        if (d > k) {
            cur = dfa[cur];
            continue;
        }

        res += t[0][dfa[cur]].query(k - d);
        res -= t[1][cur].query(k - d);

        cur = dfa[cur];
    }

    return res;
}

void modify(int x, int v){
    int cur = x;
    
    while (cur) {
        t[0][cur].modify(dist(x, cur), v);
        if (dfa[cur]) {
            t[1][cur].modify(dist(x, dfa[cur]), v);
        }

        cur = dfa[cur];
    }
}

int main(){
    scanf("%d %d", &n, &m);

    for (int i = 1; i <= n; i++) {
        scanf("%d", val + i);
    }

    for (int i = 1; i < n; i++) {
        int u, v;
        scanf("%d %d", &u, &v);

        g[u].push_back(v);
        g[v].push_back(u);
    }

    dfs1(1, 0);
    dfs2(1, 1);

    GetRoot(1, 0, n);
    GetRoot(rt, 0, n);
    dfs(rt);

    for (int i = 1; i <= n; i++) {
        modify(i, val[i]);
    }

    int lst = 0;
    for (int i = 1; i <= m; i++) {
        int op, x, y;
        scanf("%d %d %d", &op, &x, &y);

        x ^= lst, y ^= lst;

        if (op == 0) {
            printf("%d\n", lst = query(x, y));
        } else {
            modify(x, y - val[x]);
            val[x] = y;
        }
    }
    return 0;
}
posted @ 2024-06-07 16:22  Z_drj  阅读(6)  评论(0编辑  收藏  举报