[算法入门]树链剖分 - 轻重链剖分

#0.0 前置知识

为了保证本文的阅读愉快,建议熟练掌握以下知识:



#1.0 啥是树剖 & 树剖能干啥

树链剖分用于将树分割成若干条链的形式,以维护树上路径的信息。

具体来说,将整棵树剖分为若干条链,使它组合成线性结构,然后用其他的数据结构维护信息。

重链剖分可以将树上的任意一条路径划分成不超过 \(\log n\) 条连续的链,每条链上的点深度互不相同(即是自底向上的一条链,链上所有点的 LCA 为链的一个端点)。

重链剖分还能保证划分出的每条链上的节点 DFS 序连续,因此可以方便地用一些维护序列的数据结构(如线段树)来维护树上路径的信息。

如:

  • 修改 树上两点之间的路径上 所有点的值。
  • 查询 树上两点之间的路径上 节点权值的 和/极值/其它(在序列上可以用数据结构维护,便于合并的信息)

除了配合数据结构来维护树上路径信息,树剖还可以用来 (且常数较小)地求 LCA。在某些题目中,还可以利用其性质来灵活地运用树剖。

#2.0 重链剖分

正如文题,本文主要讲述重链剖分的实现

#2.1 基本变量 & 定义

重儿子: 所有儿子中子树结点数最多的儿子

重边:由结点与该点重儿子组成的树边

重链:由若干条重边首尾相连组成的长链

  • depth[x] 结点 \(x\) 的深度
  • f[x] 结点 \(x\) 的父结点编号
  • son[x] 结点 \(x\) 的重儿子编号
  • size[x]\(x\) 为根的结点个数(包括自己)
  • top[x] \(x\) 所在重链的顶部结点
  • id[x],rk[y] 字典(?)数组,将原树中的结点映射到线段树中的结点与它的逆操作

#2.2 读入 & 储存

我们这里采用邻接表储存双向边(每一条树边存两个方向)

#2.3 解剖的实现 - 1

这是树剖的核心操作一

这一部分,我们主要是要完成以下任务:

  • 搞清出每个结点的深度
  • 摸清父子关系,统计出以以 \(x\) 为根的结点个数 size[x]
  • 找到重儿子

很显然,以上三个任务可以通过一次 dfs 解决

void dfs1(int x,int fa,int depth){
    f[x] = fa;
    d[x] = depth;
    size[x] = 1;
    for (int i = head[x];i != -1;i = e[i].next){
        int v = e[i].v;
        if (v == fa) //注意,因为是双向边,会有连向父结点的边
          continue;
        dfs1(v,x,depth + 1); //递归求解子结点
        size[x] += size[v];
        if (size[v] > size[son[x]]) //更新重儿子
          son[x] = v;
    }
}

#2.4 解剖的实现 - 2

#2.4.1 目标及实现

这次,我们要完成的任务有

  • 确定所在重链的头
  • 捋清与线段树上结点编号的关系

显然,依旧可以采用 dfs 实现,这里有几个值得强调的点:

  • 我们将一个落单的点看做一个新的重链,并将它作为以他为根的重儿子所在重链的头
  • 优先递归重儿子(原因在下面)
void dfs2(int u,int t){
    top[u] = t;
    id[u] = tot;
    rk[tot ++] = u;
    if (!son[u]) //叶子结点,直接返回
      return;
    dfs2(son[u],t); //优先递归重儿子
    for (int i = head[u];i != -1;i = e[i].next){
        int v= e[i].v;
        if (v != son[u] && v != f[u])
          dfs2(v,v); //注意不要重复递归与错误递归
    }
}

#2.4.2 对于某些问题の解释

Q:为啥优先递归重儿子?

别着急,我们先来看看如果这样做会有怎样的效果

显然,这样做一条重链上的所有结点对应的编号连续(dfs 序连续),这样的话,如果我们要按照 dfs 序建立一颗线段树,那么一条重链对应的结点编号显然是连续的,我们如果要对一条重链进行修改,可以直接修改一个区间

#3.0 求 LCA

好的,经过以上两场手术,一颗完整的树已经被解剖成为大大小小的链了,在进行更深一步的探索前,我们先来学习一个必备的操作——用重链求树上两点的 LCA

#3.1 简单分析

首先,我们考虑什么情况可以直接返回 \(x,y\) 的 LCA?

显然是二者在同一条重链上时,那么,为了达到这个目标,我们只需要每次让所在重链的头深度较大的结点向上跳,跳到该重链的头的父节点,重复这个过程即可

#3.2 实现

inline int lca(int a,int b){
    while (top[a] != top[b]){
        if (d[top[a]] < d[top[b]])
          swap(a,b);
        a = f[top[a]];
    }
    if (d[a] < d[b])
      swap(a,b);
    return b;
}

#4.0 利用数据结构维护信息

这里以线段树为例(因为我不会其他的

假如我们要维护这样一棵树:

  • 可为从 \(a\)\(b\) 的路径上的每个结点加 \(x\)
  • 为以 \(a\) 为根的子树上的所有结点加 \(x\)
  • 求从 \(a\)\(b\) 的路径上的每个结点的权值和
  • \(a\) 为根的子树上的所有结点的权值和

让我们一样一样来

#4.1 路径修改

首先,从 \(a\)\(b\) 的路径是什么?

是不是就是从 \(a\)\(\text{LCA}(a,b)\),再从 \(\text{LCA}(a,b)\)\(b\),那么,我们可不可以借用上面求 LCA 的思想,每跳一次,给重链进行区间修改,最终达到修改整个路径

还记得我们解剖时的一个操作吗?让一条重链上的结点对应的线段树上的编号连续,所以就可以直接进行区间修改

inline void update(int l,int r,int x,int k){
    if (l <= a[k].l && r >= a[k].r){
        add(k,x);
        return;
    }
    pushdown(k);
    int mid = (a[k].l + a[k].r) >> 1;
    if (mid >= l)
      update(l,r,x,a[k].ls);
    if (mid < r)
      update(l,r,x,a[k].rs);
    pushup(k);
}

inline void updates(int x,int y,int c){
    while (top[x] != top[y]){
        if ( d[top[x]] < d[top[y]])
          swap(x,y);
        update(id[top[x]],id[x],c,rt);
        x = f[top[x]];
    }
    if (id[x] > id[y])
      swap(x,y);
    update(id[x],id[y],c,rt);
}

#4.2 子树修改

首先,dfs 序具有这样一个性质:一颗子树上的点 dfs 序连续

所以我们就可以直接采用如下代码进行区间修改

update(id[x],id[x] + size[x] - 1,y,rt);

#4.3 路径查询

与路径修改一样的思想

inline int query(int l,int r,int k){
    if (l <= a[k].l && a[k].r <= r)
      return a[k].sum;
    pushdown(k);
    int mid = (a[k].l + a[k].r) >> 1;
    int ans = 0;
    if (mid >= l)
      (ans += query(l,r,a[k].ls)) %= mod;
    if (mid < r)
      (ans += query(l,r,a[k].rs)) %= mod;
    return ans;
}

inline int sum(int x,int y){
    int ans = 0;
    while (top[x] != top[y]){
        if (d[top[x]] < d[top[y]])
          swap(x,y);
        (ans += query(id[top[x]],id[x],rt)) %= mod;
        x = f[top[x]];
    }
    if (id[x] > id[y])
      swap(x,y);
    return (ans + query(id[x],id[y],rt)) % mod;
}

#4.4 子树查询

与子树修改类似

#5.0 例题

洛谷 P3384 【模板】轻重链剖分

板子题

#include <iostream>
#include <cstring>
#include <cstdio>
#include <algorithm>
#include <queue>
#define INF 0x3fffffff
#define N 100100
#define ll long long
#define mset(l,x) memset(l,x,sizeof(l))
#define mp(a,b) make_pair(a,b)
using namespace std;

struct Edge{
    int u;
    int v;
    int next;
};
Edge e[N * 4];

struct Node{
    int l,r;   //区间左右端点 
    int ls,rs; //左右子树根节点编号 
    int sum;   //区间和 
    int lazy;
};
Node a[N * 4];

int n,m,root,mod;
int value[N],head[N];
int d[N],f[N],size[N],son[N],tot;
int top[N],id[N],rk[N],rt; 

inline void add_edge(int u,int v){
    e[tot].u = u;
    e[tot].v = v;
    e[tot].next = head[u];
    head[u] = tot ++;
}

void dfs1(int x,int fa,int depth){
    f[x] = fa;
    d[x] = depth;
    size[x] = 1;
    for (int i = head[x];i != -1;i = e[i].next){
        int v = e[i].v;
        if (v == fa)
          continue;
        dfs1(v,x,depth + 1);
        size[x] += size[v];
        if (size[v] > size[son[x]])
          son[x] = v;
    }
}

void dfs2(int u,int t){
    top[u] = t;
    id[u] = tot;
    rk[tot ++] = u;
    if (!son[u])
      return;
    dfs2(son[u],t);
    for (int i = head[u];i != -1;i = e[i].next){
        int v= e[i].v;
        if (v != son[u] && v != f[u])
          dfs2(v,v);
    }
}

inline int len(int k){
    return a[k].r - a[k].l + 1;
}

inline void pushup(int k){
    a[k].sum = (a[a[k].ls].sum + a[a[k].rs].sum) % mod;
}

inline void add(int k,int x){
    (a[k].lazy += x) %= mod;
    (a[k].sum += len(k) * x) %=  mod;
}

inline void build(int l,int r,int k){
    if (l == r){
        a[k].sum = value[rk[l]];
        a[k].l = a[k].r = l;
        return;
    }
    int mid = (l + r) >> 1;
    a[k].ls = tot ++;
    build(l,mid,a[k].ls);
    a[k].rs = tot ++;
    build(mid + 1,r,a[k].rs);
    a[k].l = a[a[k].ls].l,a[k].r = a[a[k].rs].r;
    pushup(k);
}

inline void pushdown(int k){
    if (a[k].lazy){
        int ls = a[k].ls,rs = a[k].rs;
        add(ls,a[k].lazy);
        add(rs,a[k].lazy);
        a[k].lazy = 0;
    }
}

inline void update(int l,int r,int x,int k){
    if (l <= a[k].l && r >= a[k].r){
        add(k,x);
        return;
    }
    pushdown(k);
    int mid = (a[k].l + a[k].r) >> 1;
    if (mid >= l)
      update(l,r,x,a[k].ls);
    if (mid < r)
      update(l,r,x,a[k].rs);
    pushup(k);
}

inline int query(int l,int r,int k){
    if (l <= a[k].l && a[k].r <= r)
      return a[k].sum;
    pushdown(k);
    int mid = (a[k].l + a[k].r) >> 1;
    int ans = 0;
    if (mid >= l)
      (ans += query(l,r,a[k].ls)) %= mod;
    if (mid < r)
      (ans += query(l,r,a[k].rs)) %= mod;
    return ans;
}

inline int sum(int x,int y){
    int ans = 0;
    while (top[x] != top[y]){
        if (d[top[x]] < d[top[y]])
          swap(x,y);
        (ans += query(id[top[x]],id[x],rt)) %= mod;
        x = f[top[x]];
    }
    if (id[x] > id[y])
      swap(x,y);
    return (ans + query(id[x],id[y],rt)) % mod;
}

inline void updates(int x,int y,int c){
    while (top[x] != top[y]){
        if ( d[top[x]] < d[top[y]])
          swap(x,y);
        update(id[top[x]],id[x],c,rt);
        x = f[top[x]];
    }
    if (id[x] > id[y])
      swap(x,y);
    update(id[x],id[y],c,rt);
}

int main(){
    mset(head,-1);
    scanf("%d%d%d%d",&n,&m,&root,&mod);
    for (int i = 1;i <= n;i ++)
      scanf("%d",&value[i]);
    for (int i = 1;i < n;i ++){
        int u,v;
        scanf("%d%d",&u,&v);
        add_edge(u,v);
        add_edge(v,u);
    }
    
    dfs1(root,0,1);
    tot = 1;
    dfs2(root,root);
    tot = 0;
    build(1,n,rt = tot ++);
    
    while (m --){
        int s;
        scanf("%d",&s);
        if (s == 1){
            int x,y,z;
            scanf("%d%d%d",&x,&y,&z);
            updates(x,y,z);
        }
        else if (s == 2){
            int x,y;
            scanf("%d%d",&x,&y);
            printf("%d\n",sum(x,y));
        }
        else if (s == 3){
            int x,y;
            scanf("%d%d",&x,&y);
            update(id[x],id[x] + size[x] - 1,y,rt);
        }
        else{
            int x;
            scanf("%d",&x);
            printf("%d\n",query(id[x],id[x] + size[x] - 1,rt));
        }
    }
    return 0;
}

参考资料

OI-wiki - 树链剖分

posted @ 2021-03-15 19:59  Dfkuaid  阅读(70)  评论(0编辑  收藏  举报