树链剖分总结

树链剖分思想不是很复杂。首先给出几个定义吧:

  • 重儿子:对于每一个非叶子节点,它的儿子中 儿子数量最多的那一个儿子 为该节点的重儿子
  • 轻儿子:对于每一个非叶子节点,它的儿子中 非重儿子 的剩下所有儿子即为轻儿子
  • 重边:连接任意两个重儿子的边叫做重边
  • 轻边:剩下的即为轻边
  • 重链:相邻重边连起来的 连接一条重儿子 的链叫重链

核心思想就是,将一棵树拆成多条链,然后对于每一条链,就用数据结构去维护。
有个不会证明的性质,就是如果将一颗树拆成多条重链和轻边,那么重链的个数不会超过\(log_2n\),轻边的边数也不会超过\(log_2n\)。因为这个性质,很多操作我们可以很高效地完成。这个之后就知道了。

那么我们如何去拆分一颗树呢?通过两次dfs即可解决。
首先第一次dfs,我们可以处理出每个结点的深度\(deep[u]\),以它为根的子树中点的数量\(sz[u]\),以及每个点的父亲结点\(fa[u]\),并且可以求出每个点的重儿子\(son[u]\)
代码如下:

void dfs1(int u, int pa, int d) {
    deep[u] = d;
    fa[u] = pa;
    sz[u] = 1;
    int mx = -1;
    for(int i = head[u]; i != -1; i = e[i].next) {
        int v = e[i].v;
        if(v == pa) continue ;
        dfs1(v, u, d + 1) ;
        sz[u] += sz[v] ;
        if(sz[v] > mx) mx = sz[v], son[u] = v;
    }
}

 
之后进行第二次dfs,这里我们就需要处理出链了。我们知道,dfs序可以将树结构哈希成线性结构,然后方便我们去维护。其实一般树链剖分都要利用dfs序。但是这里因为我们要维护一条链的信息,所以我们应该在dfs的时候优先选择重儿子,并且维护每条链的顶端结点,方便我们后续的操作。
优先选择重儿子就可以保证一条链的dfs序是连续的,方便我们之后进行维护。
第二次dfs的代码如下:

void dfs2(int u, int topf) {
    id[u] = ++cnt;
    v[cnt] = w[u] ;
    top[u] = topf;
    if(!son[u]) return ;
    dfs2(son[u], topf) ;
    for(int i = head[u]; i != -1; i = e[i].next) {
        int v = e[i].v;
        if(v != son[u] && v != fa[u]) dfs2(v, v) ;
    }
}

这样我们就处理出每条链了,top数组保存的就是每条链的顶端结点。对于两点间的路径,我们就可以利用top来加速。
以上就是树链剖分的核心部分吧。其余部分就是相应的数据结构去维护信息了。
现在看起来不是很难,但是以前听都听不懂。。还是以前学习的态度不是很认真吧。

下面给出一道模板题吧:
 

[洛谷P3384](https://www.luogu.org/problemnew/show/P3384)
题目中对于路径的修改和询问,我们就类似于求LCA那样往上面跳就行了,同时在跳的过程中,记得对链上面的信息进行维护。 题目中对于子树的修改和询问,也类似,因为他们的dfs序都是连续的,所以直接通过dfs序来维护信息就好啦。 以下为代码(两次dfs不长,加上线段树就有点长了啊。。):
Code
#include <bits/stdc++.h>
using namespace std;
typedef long long ll;
const int N = 2e5 + 5;
int n, m, r, mod;
int v[N];
int deep[N], son[N], fa[N], top[N], id[N], w[N], sz[N];
int cnt ;
int sum[4 * N], lazy[4 * N];
struct Edge{
    int v, next ;
}e[N << 1];
int head[N], tot;
void adde(int u, int v) {
    e[tot].v = v; e[tot].next = head[u]; head[u] = tot++;
    e[tot].v = u; e[tot].next = head[v]; head[v] = tot++;
}
void dfs1(int u, int pa, int d) {
    deep[u] = d;
    fa[u] = pa;
    sz[u] = 1;
    int mx = -1;
    for(int i = head[u]; i != -1; i = e[i].next) {
        int v = e[i].v;
        if(v == pa) continue ;
        dfs1(v, u, d + 1) ;
        sz[u] += sz[v] ;
        if(sz[v] > mx) mx = sz[v], son[u] = v;
    }
}
void dfs2(int u, int topf) {
    id[u] = ++cnt;
    v[cnt] = w[u] ;
    top[u] = topf;
    if(!son[u]) return ;
    dfs2(son[u], topf) ;
    for(int i = head[u]; i != -1; i = e[i].next) {
        int v = e[i].v;
        if(v != son[u] && v != fa[u]) dfs2(v, v) ;
    }
}
void pushup(int o) {
    sum[o] = (sum[o << 1] + sum[o << 1|1]) % mod ;
}
void pushdown(int o, int l, int r) {
    if(lazy[o]) {
        int mid = (l + r) >> 1;
        sum[o << 1] = (sum[o << 1] + lazy[o] * (mid - l + 1)) % mod ;
        sum[o << 1|1] = (sum[o << 1|1] + lazy[o] * (r - mid)) % mod ;
        lazy[o << 1] = (lazy[o << 1] + lazy[o]) % mod ;
        lazy[o << 1|1] = (lazy[o << 1|1] + lazy[o]) % mod ;
        lazy[o] = 0;
    }
}
void build(int o, int l, int r) {
    if(l == r) {
        sum[o] = v[l] ;
        return ;
    }
    int mid = (l + r) >> 1;
    build(o << 1, l, mid) ;
    build(o << 1|1, mid + 1, r) ;
    pushup(o) ;
}
void update(int L, int R, int o, int l, int r,int val) {
    if(L <= l && r <= R) {
        sum[o] = (sum[o] + (r - l + 1) * val) % mod;
        lazy[o] = (lazy[o] + val) % mod;
        return ;
    }
    pushdown(o, l, r) ;
    int mid = (l + r) >> 1;
    if(L <= mid) update(L, R, o << 1, l, mid, val) ;
    if(R > mid) update(L, R, o << 1|1, mid + 1, r, val) ;
    pushup(o) ;
}
int query(int L, int R, int o, int l, int r) {
    if(L <= l && r <= R)
        return sum[o] ;
    pushdown(o, l, r);
    int mid = (l + r) >> 1;
    int ans = 0;
    if(L <= mid) ans = (ans + query(L, R, o << 1, l, mid)) % mod;
    if(R > mid) ans = (ans + query(L, R, o << 1|1, mid + 1, r)) % mod;
    return ans ;
}
int main() {
    ios::sync_with_stdio(false); cin.tie(0);
    cin >> n >> m >> r >> mod;
    memset(head, -1, sizeof(head)) ;
    for(int i = 1; i <= n; i++) cin >> w[i];
    for(int i = 1; i < n; i++) {
        int u, v;
        cin >> u >> v;
        adde(u, v);
    }
    dfs1(r, 0, 1) ;
    dfs2(r, r) ;
    build(1, 1, n);
    for(int i = 1; i <= m; i++) {
        int op, x, y, z;
        cin >> op >> x ;
        if(op == 1) {
            cin >> y >> z;
            while(top[x] != top[y]) {
                if(deep[top[x]] < deep[top[y]]) swap(x, y) ;
                update(id[top[x]], id[x], 1, 1, n, z) ;
                x = fa[top[x]] ;
            }
            if(deep[x] < deep[y]) swap(x, y) ;
            update(id[y], id[x], 1, 1, n, z) ;
        } else if(op == 2) {
            cin >> y;
            int ans = 0;
            while(top[x] != top[y]) {
                if(deep[top[x]] < deep[top[y]]) swap(x, y) ;
                ans = (ans + query(id[top[x]], id[x], 1, 1, n)) % mod;
                x = fa[top[x]] ;
            }
            if(deep[x] < deep[y]) swap(x, y) ;
            ans = (ans + query(id[y], id[x], 1, 1, n)) % mod ;
            cout << ans << '\n';
        } else if(op == 3) {
            cin >> z;
            update(id[x], id[x] + sz[x] - 1, 1, 1, n, z) ;
        } else {
            int ans = 0;
            ans = query(id[x], id[x] + sz[x] - 1, 1, 1, n);
            cout << ans << '\n';
        }
    }
    return 0;
}

未完待续...

posted @ 2019-04-22 21:51  heyuhhh  阅读(314)  评论(0编辑  收藏  举报