学习笔记:树链剖分

树链剖分

引入

简单来说,树链剖分就是通过某种方式将一棵树划分为几条链,再利用数据结构来维护树上路径。

具体地讲,可以将树上的任意一条路径划分为不超过 logn 条连续的链,每条链上的点深度互不相同(即是自底向上的一条链,链上所有点的 LCA 为链的一个端点),并且保证划分出的每条链上的节点 DFS 序连续,因此可以方便地用一些维护序列的数据结构(如线段树、树状数组)来维护树上路径的信息。

维护树上路径

首先来看一道树链剖分的板子题。

洛谷 P3384【模板】重链剖分/树链剖分

题目描述

如题,已知一棵包含 N 个结点的树(连通且无环),每个节点上包含一个数值,需要支持以下操作:

  • 1 x y z,表示将树从 xy 结点最短路径上所有节点的值都加上 z

  • 2 x y,表示求树从 xy 结点最短路径上所有节点的值之和。

  • 3 x z,表示将以 x 为根节点的子树内所有节点值都加上 z

  • 4 x 表示求以 x 为根节点的子树内所有节点值之和。

思路

我们给出一些定义:

定义 重子节点 表示其子节点中子树最大的子结点。如果有多个子树最大的子结点,取其一。如果没有子节点,就无重子节点。

定义 轻子节点 表示剩余的所有子结点。

从这个结点到重子节点的边为 重边

到其他轻子节点的边为 轻边

若干条首尾衔接的重边构成 重链

把落单的结点也当作重链,那么整棵树就被剖分成若干条重链。

至于具体做法的话,我们考虑进行两次 dfs。

第一个 DFS 记录每个结点的父节点(father)、深度(deep)、子树大小(size)、重子节点(hson)。

void dfs1(int now, int fat, int deep){
    dep[now] = deep;siz[now] = 1;fa[now] = fat;int maxson = -1;
    for(int i = head[now] ; i != 0 ; i = e[i].nxt){
        int v = e[i].to;
        if(v != fat){
            dfs1(v, now, deep + 1);siz[now] += siz[v];
            if(siz[v] > maxson){
                maxson = siz[v];son[now] = v;
            }
        }
    }
}

第二个 DFS 记录所在链的链顶(top,应初始化为结点本身)、重边优先遍历时的 DFS 序(dfn)、DFS 序对应的节点权值。

void dfs2(int now, int fat, int top){
    dfn[now] = ++tot;wt[tot] = w[now];vis[now] = top;
    if(son[now] != 0){
        dfs2(son[now], now, top);
        for(int i = head[now] ; i != 0 ; i = e[i].nxt){
            int v = e[i].to;
            if(v != fat && v != son[now])dfs2(v, now, v);
        }
    }
}

以下为代码实现。

我们先给出一些定义:

  • fa(x) 表示节点 x 在树上的父亲。
  • dep(x) 表示节点 x 在树上的深度。
  • siz(x) 表示节点 x 的子树的节点个数。
  • son(x) 表示节点 x重儿子
  • vis(x) 表示节点 x 所在重链的顶部节点(深度最小)。
  • dfn(x) 表示节点 xDFS 序,也是其在线段树中的编号。
  • wt(x) 表示 DFS 序所对应的节点权值。

我们进行两遍 DFS 预处理出这些值,其中第一次 DFS 求出 fa(x)dep(x)siz(x)son(x),第二次 DFS 求出 vis(x)dfn(x)wt(x)

现在回顾一下我们要处理的问题:

  • 处理任意两点间路径上的点权和。
  • 处理一点及其子树的点权和。
  • 修改任意两点间路径上的点权。
  • 修改一点及其子树的点权。

1、当我们要处理任意两点间路径时: 设所在链顶端的深度更深的那个点为 x 点。

  • ans 加上 x 点到 x 所在链顶端 这一段区间的点权和。
  • x 跳到 x 所在链顶端的那个点的上面一个点。

不停执行这两个步骤,直到两个点处于一条链上,这时再加上此时两个点的区间和即可。

img

这时我们注意到,我们所要处理的所有区间均为连续编号(新编号),于是想到线段树,用线段树处理连续编号区间和,每次查询的时间复杂度为 O(log2n)

2、处理一点及其子树的点权和:

想到记录了每个非叶子节点的子树大小(含它自己),并且每个子树的新编号都是连续的。

于是直接线段树区间查询即可。

时间复杂度为 O(logn)

代码

#include <iostream>
#define MAXN 100005
#define MAXM 100005
int n, m, r, p, u, v;
int op, x, y, z;
int w[MAXN], wt[MAXN];
struct edge{int to, nxt;}e[MAXN << 1];
int head[MAXN], cnt;
int tree[MAXN << 2], mark[MAXN << 2];
int dep[MAXN], siz[MAXN], son[MAXN], fa[MAXN];
int dfn[MAXN], vis[MAXN];
int tot, ans;
int read(){
    int t = 1, x = 0;char ch = getchar();
    while(!isdigit(ch)){if(ch == '-')t = -1;ch = getchar();}
    while(isdigit(ch)){x = (x << 1) + (x << 3) + (ch ^ 48);ch = getchar();}
    return x * t;
}
void write(int x){
    if(x < 0){putchar('-');x = -x;}
    if(x >= 10)write(x / 10);
    putchar(x % 10 + '0');
}
void pushup(int node){tree[node] = tree[node << 1] + tree[node << 1 | 1];tree[node] %= p;}
void pushdown(int node, int len){
    if(mark[node] != 0){
        tree[node << 1] += mark[node] * (len - (len >> 1));tree[node << 1] %= p;
        mark[node << 1] += mark[node];mark[node << 1] %= p;
        tree[node << 1 | 1] += mark[node] * (len >> 1);tree[node << 1 | 1] %= p;
        mark[node << 1 | 1] += mark[node];mark[node << 1 | 1] %= p;
        mark[node] = 0;
    }
}
void build(int node, int left, int right){
    if(left == right){tree[node] = wt[left];return;}
    int mid = left + right >> 1;
    build(node << 1, left, mid);build(node << 1 | 1, mid + 1, right);
    pushup(node);
}
void update(int node, int left, int right, int l, int r, int k){
    if(l <= left && r >= right){
        tree[node] += k * (right - left + 1);tree[node] %= p;
        mark[node] += k;mark[node] %= p;
        return;
    }
    pushdown(node, right - left + 1);int mid = left + right >> 1;
    if(l <= mid)update(node << 1, left, mid, l, r, k);
    if(r > mid)update(node << 1 | 1, mid + 1, right, l, r, k);
    pushup(node);
}
void query(int node, int left, int right, int l, int r){
    if(l <= left && r >= right){ans += tree[node];ans %= p;return;}
    pushdown(node, right - left + 1);int mid = left + right >> 1;
    if(l <= mid)query(node << 1, left, mid, l, r);
    if(r > mid)query(node << 1 | 1, mid + 1, right, l, r);
}
void add(int u, int v){e[++cnt].to = v;e[cnt].nxt = head[u];head[u] = cnt;}
void swap(int &a, int &b){a ^= b ^= a ^= b;}
void dfs1(int now, int fat, int deep){
    dep[now] = deep;siz[now] = 1;fa[now] = fat;int maxson = -1;
    for(int i = head[now] ; i != 0 ; i = e[i].nxt){
        int v = e[i].to;
        if(v != fat){
            dfs1(v, now, deep + 1);siz[now] += siz[v];
            if(siz[v] > maxson){
                maxson = siz[v];son[now] = v;
            }
        }
    }
}
void dfs2(int now, int fat, int top){
    dfn[now] = ++tot;wt[tot] = w[now];vis[now] = top;
    if(son[now] != 0){
        dfs2(son[now], now, top);
        for(int i = head[now] ; i != 0 ; i = e[i].nxt){
            int v = e[i].to;
            if(v != fat && v != son[now])dfs2(v, now, v);
        }
    }
}
void updtree(int x, int y, int z){
    z %= p;
    while(vis[x] != vis[y]){
        if(dep[vis[x]] < dep[vis[y]])swap(x, y);
        update(1, 1, n, dfn[vis[x]], dfn[x], z);
        x = fa[vis[x]];
    }
    if(dep[x] > dep[y])swap(x, y);
    update(1, 1, n, dfn[x], dfn[y], z);
}
int quetree(int x, int y){
    int res = 0;
    while(vis[x] != vis[y]){
        if(dep[vis[x]] < dep[vis[y]])swap(x, y);
        ans = 0;query(1, 1, n, dfn[vis[x]], dfn[x]);
        res += ans;res %= p;x = fa[vis[x]];
    }
    if(dep[x] > dep[y])swap(x, y);
    ans = 0;query(1, 1, n, dfn[x], dfn[y]);
    res += ans;res %= p;
    return res;
}
void updson(int x, int z){update(1, 1, n, dfn[x], dfn[x] + siz[x] - 1, z);}
int queson(int x){ans = 0;query(1, 1, n, dfn[x], dfn[x] + siz[x] - 1);return ans;}
int main(){
    n = read();m = read();r = read();p = read();
    for(int i = 1 ; i <= n ; i ++)w[i] = read();
    for(int i = 1 ; i < n ; i ++){u = read();v = read();add(u, v);add(v, u);}
    dfs1(r, 0, 1);dfs2(r, 0, r);build(1, 1, n);
    for(int i = 1 ; i <= m ; i ++){op = read();
        if(op == 1){x = read();y = read();z = read();updtree(x, y, z);}
        if(op == 2){x = read();y = read();write(quetree(x, y));putchar('\n');}
        if(op == 3){x = read();z = read();updson(x, z);}
        if(op == 4){x = read();write(queson(x));putchar('\n');}
    }return 0;
}

LCA

不断向上跳重链,当跳到同一条重链上时,深度较小的结点即为 LCA。

向上跳重链时需要先跳所在重链顶端深度较大的那个。

#include <iostream>
#define MAXN 500005
using namespace std;
int n, m, s, x, y;
struct edge{int to, nxt;}e[MAXN << 1];
int head[MAXN], cnt = 1;
int son[MAXN], fa[MAXN], dep[MAXN], siz[MAXN];
int dfn[MAXN], vis[MAXN], tot;
int read(){
    int t = 1, x = 0;char ch = getchar();
    while(!isdigit(ch)){if(ch == '-')t = -1;ch = getchar();}
    while(isdigit(ch)){x = (x << 1) + (x << 3) + (ch ^ 48);ch = getchar();}
    return x * t;
}
void write(int x){
    if(x < 0){putchar('-');x = -x;}
    if(x >= 10)write(x / 10);
    putchar(x % 10 ^ 48);
}
void add(int u, int v){
    cnt++;e[cnt].to = v;e[cnt].nxt = head[u];head[u] = cnt;
    cnt++;e[cnt].to = u;e[cnt].nxt = head[v];head[v] = cnt;
}
void dfs1(int now, int fat, int deep){
    fa[now] = fat;siz[now] = 1;dep[now] = deep;int maxson = -1;
    for(int i = head[now] ; i != 0 ; i = e[i].nxt){
        int v = e[i].to;
        if(v != fat){
            dfs1(v, now, deep + 1);
            siz[now] += siz[v];
            if(siz[v] > maxson)
                maxson = siz[v],son[now] = v;
        }
    }
}
void dfs2(int now, int fat, int top){
    tot++;dfn[now] = tot;vis[now] = top;
    if(son[now] != 0)dfs2(son[now], now, top);
    for(int i = head[now] ; i != 0 ; i = e[i].nxt){
        int v = e[i].to;
        if(v != fat && v != son[now])
            dfs2(v, now, v);
    }
}
int lca(int x, int y){
    while(vis[x] != vis[y]){
        if(dep[vis[x]] >= dep[vis[y]])x = fa[vis[x]];
        else y = fa[vis[y]];
    }
    if(dep[x] < dep[y])return x;
    else return y;
}
int main(){
    n = read();m = read();s = read();
    for(int i = 1 ; i < n ; i ++)
        x = read(),y = read(),add(x, y);
    dfs1(s, 0, 1);dfs2(s, 0, s);
    for(int i = 1 ; i <= m ; i ++)
        x = read(),y = read(),write(lca(x, y)),putchar('\n');
    return 0;
}
posted @   tsqtsqtsq  阅读(4)  评论(0编辑  收藏  举报  
相关博文:
阅读排行:
· 阿里最新开源QwQ-32B,效果媲美deepseek-r1满血版,部署成本又又又降低了!
· 单线程的Redis速度为什么快?
· SQL Server 2025 AI相关能力初探
· AI编程工具终极对决:字节Trae VS Cursor,谁才是开发者新宠?
· 展开说说关于C#中ORM框架的用法!
点击右上角即可分享
微信分享提示