树链剖分

树链剖分是一种可以 应付毒瘤出题人 将树上问题转换为线性数组问题的算法

另外在观看文章前,普及一下DFS序,

就是对一棵树进行DFS,其中第 \(i\) 个点是第 \(p_i\) 个遍历到的,

那么其 DFS 序就是 \(p_i\)

重链剖分

例题

例题来自 洛谷树链剖分模板题

大意是说:

给定包含 \(n\) 个节点的树,节点 \(i\) 包含一个节点值为 \(a_i\),请维护以下操作:

  • 1 x y z,将 \(x\)\(y\) 的最短路径上的节点值加上 \(z\)
  • 2 x y,询问 \(x\)\(y\) 的最短路径上的节点值之和
  • 3 x z,将以 \(x\) 为根的子树中所有节点的节点值加上 \(z\)
  • 4 x,询问以 \(x\) 为根的子树中所有节点的节点值之和

分析

这道题看起来有些似曾相识的感觉,似乎就有点像线段树

所以就想想如何把这棵树转换成一个序列,

就会想到用 DFS 序,(如果学过替罪羊树的话,可以类比替罪羊树的拍扁重构,把整棵树拍扁)

然后不论是哪种优先条件,对于任意子树,其内部节点的 DFS 序都应是连续的

所以直接线段树维护就行,但是两点间的路径呢?

重链剖分与其性质

这里就要用到重链剖分了:

定义 重子节点 表示其子节点中子树最大的子结点。如果有多个子树最大的子结点,取其一。如果没有子节点,就无重子节点。

定义 轻子节点 表示剩余的所有子结点。

从这个结点到重子节点的边为 重边

到其他轻子节点的边为 轻边

若干条首尾衔接的重边构成 重链

把落单的结点也当作重链,那么整棵树就被剖分成若干条重链。

而对于每一条重链,有四个特殊性质,即:

  • 在剖分时 重边优先遍历,最后树的 DFS 序上,重链内的 DFS 序是连续的
  • 树上每个节点都属于且仅属于一条重链,所有的重链将整棵树 完全剖分
  • 重链开头的结点不一定是重子节点(因为重边是对于每一个结点都有定义的)
  • 所有的重链剖分都是从上到下的

所以就可以将每条重链变为一个树上的区间

但是问题就来了,任意两点的最短路径中有多少条重链呢?

可以发现,当我们向下经过一条 轻边 时,所在子树的大小至少会被除以二

也就是说,从任意一个节点往下,最多走 \(\log n\) 条轻边,也就最多跨过 \(\log n\) 条链

因此,对于树上的任意一条路径,把它拆分成从LCA分别向两边往下走,分别最多走 $\log n $ 次,合起来,树上的每条路径都可以被拆分成不超过 \(2 \log n\) 条重链

总结思路

由于任意子树的 DFS 序连续的,所以任意子树加可以被转化为区间加

同理,任意子树查询被转化为区间查询

而对于路径加,就可以将其路径上的所有重链加一下,

也就是多个区间加,由于路径上最多有 \(2 \log n\) 条重链,

所以是 \(2 \log n\) 次区间加

同理,路径查询变为 \(2 \log n\) 次区间查询

区间加具体来说,可以用线段树或者树状数组维护,复杂度是 \(O(\log n)\)

所以对于后两种子树操作,复杂度为 \(O(\log n)\)

但对于前两种路径操作,由于都被转化为了 \(2 \log n\) 次区间操作,

所以复杂度为 \(O(\log^2 n)\)

最坏的总复杂度为 \(O(m\log^2 n)\)

看起来时间复杂度还是很不错的

代码

#include<iostream>
#include<cstdio>
#include<cstring>
using namespace std;
const int MAXN=1e5+7;
struct Edge
{
    int v,nxt;
} e[MAXN<<1];
int d[MAXN<<2],b[MAXN<<2],dfn[MAXN],cnt,head[MAXN],tot,a[MAXN];
int dep[MAXN],hs[MAXN],sz[MAXN],tp[MAXN],val[MAXN],fa[MAXN],n,m,r,mod;

inline void add(int u,int v)
{
    e[++tot].v=v;
    e[tot].nxt=head[u];
    head[u]=tot;
}

void dfs1(int u,int f)
{
    fa[u]=f;
    dep[u]=dep[f]+1;
    sz[u]=1;
    for(int i=head[u]; i; i=e[i].nxt)
    {
        int v=e[i].v;
        if(v==f) continue;
        dfs1(v,u);
        sz[u]+=sz[v];
        if(sz[v]>sz[hs[u]]) hs[u]=v;
    }
}

void dfs2(int u,int tpf)
{
    tp[u]=tpf;
    dfn[u]=++cnt;
    val[cnt]=a[u];
    if(!hs[u]) return;
    dfs2(hs[u],tpf);
    for(int i=head[u]; i; i=e[i].nxt)
    {
        int v=e[i].v;
        if(v==fa[u]||v==hs[u])continue;
        dfs2(v,v);
    }
}

inline void pushup(int p)
{
    d[p]=(d[p<<1]+d[(p<<1)+1])%mod;
}

inline void pushdown(int p,int l,int r)
{
    if(!b[p])return;
    int mid=(l+r)>>1;
    b[p<<1]+=b[p];
    b[(p<<1)+1]+=b[p];
    d[p<<1]+=(mid-l+1)*b[p];
    d[(p<<1)+1]+=(r-mid)*b[p];
    b[p]=0;
}

void build(int p,int l,int r)
{
    if(l==r)
    {
        d[p]=val[l];
        return;
    }
    int mid=(l+r)>>1;
    build(p<<1,l,mid),build((p<<1)+1,mid+1,r);
    pushup(p);
}
void update(int p,int l,int r,int u,int v,int a)
{
    if(u<=l&&r<=v)
    {
        d[p]+=(r-l+1)*a;
        b[p]+=a;
        d[p]%=mod;
        return;
    }
    pushdown(p,l,r);
    int mid=(l+r)>>1;
    if(u<=mid) update(p<<1,l,mid,u,v,a);
    if(mid<v) update((p<<1)+1,mid+1,r,u,v,a);
    pushup(p);
}
int query(int p,int l,int r,int u,int v)
{
    if(u<=l&&r<=v)return d[p];
    pushdown(p,l,r);
    int mid=(l+r)>>1,ans=0;
    if(u<=mid) ans+=query(p<<1,l,mid,u,v);
    if(mid<v) ans+=query((p<<1)+1,mid+1,r,u,v);
    return ans%mod;
}
void padd(int u,int v,int a)
{
    a%=mod;
    while(tp[u]!=tp[v])
    {
        if (dep[tp[u]]<dep[tp[v]]) swap(u,v);
        update(1,1,n,dfn[tp[u]],dfn[u],a);
        u=fa[tp[u]];
    }
    if(dep[u]>dep[v]) swap(u,v);
    update(1,1,n,dfn[u],dfn[v],a);
}
int pquery(int u,int v)
{
    int ans=0;
    while(tp[u]!=tp[v])
    {
        if(dep[tp[u]]<dep[tp[v]]) swap(u,v);
        ans+=query(1,1,n,dfn[tp[u]],dfn[u]);
        u=fa[tp[u]];
    }
    if(dep[u]>dep[v]) swap(u,v);
    ans+=query(1,1,n,dfn[u],dfn[v]);
    return ans%mod;
}

inline int read()
{
    int x=0,f=1;
    char ch=getchar();
    while(ch<'0'||ch>'9')
    {
        if(ch=='-') f=-1;
        ch=getchar();
    }
    while(ch>='0'&&ch<='9')
    {
        x=(x<<1)+(x<<3)+(ch^48);
        ch=getchar();
    }
    return x*f;
}

int main()
{
    n=read(),m=read(),r=read(),mod=read();
    for (int i=1; i<=n; i++)a[i]=read();
    for (int i=1; i<n; i++)
    {
        int u=read(),v=read();
        add(u,v);
        add(v,u);
    }
    dfs1(r,r);
    dfs2(r,r);
    build(1,1,n);
    for (int i=1; i<=m; i++)
    {
        int opt=read();
        if(opt==1)
        {
            int u=read(),v=read(),z=read();
            padd(u,v,z);
        }
        else if(opt==2)
        {
            int u=read(),v=read();
            printf("%d\n",pquery(u,v));
        }
        else if(opt==3)
        {
            int u=read(),z=read();
            update(1,1,n,dfn[u],dfn[u]+sz[u]-1,z);
        }
        else
        {
            int u=read();
            printf("%d\n",query(1,1,n,dfn[u],dfn[u]+sz[u]-1));
        }
    }
    return 0;
}

重链剖分求解LCA

这个很简单,

首先考虑当两点在同一条重链时,深度较小者显然就是LCA

但如果两点不在同一条重链上呢?就想把深度较大者不停地往上找,直到两点在同一条重链上为止

变为同一条重链后,就再次回到第一种情况了

代码如下:

int lca(int x,int y)
{
    while(tp[x]!=tp[y])
    {
        if(dep[tp[x]]>dep[tp[y]]) x=fa[tp[x]];
        else y=fa[tp[y]];
    }
    return dep[x]<dep[y]?x:y;
}

时间复杂度是预处理 \(O(n)\),查询最坏情况 \(O(\log n)\)

比倍增的时间复杂度要优化不少

长链剖分

说实话,长链剖分比重链剖分的用途要小很多

但在dp的优化,需要分类讨论的树上问题以及k级祖先上,长链剖分还是大有作用的

这里只讲一下k级祖先

但这纯属一个值得一说,但没必要写的程序,

毕竟这个算法的常数太大,可能还没有重链剖分好用

例题

题目来自 洛谷的k级祖先模板

给定一棵 \(n\) 个点的有根树。

\(q\) 次询问,第 \(i\) 次询问给定 \(x_i, k_i\),要求点 \(x_i\)\(k_i\) 级祖先,答案为 \(ans_i\)。特别地,\(ans_0 = 0\)

本题中的询问将在程序内生成。

给定一个随机种子 \(s\) 和一个随机函数 \(\operatorname{get}(x)\)

#define ui unsigned int
ui s;

inline ui get(ui x) {
	x ^= x << 13;
	x ^= x >> 17;
	x ^= x << 5;
	return s = x; 
}

你需要按顺序依次生成询问。

\(d_i\) 为点 \(i\) 的深度,其中根的深度为 \(1\)

对于第 \(i\) 次询问,\(x_i = ((\operatorname{get}(s) \operatorname{xor} ans_{i-1}) \bmod n) + 1\)\(k_i = (\operatorname{get}(s) \operatorname{xor} ans_{i-1}) \bmod d_{x_i}\)

分析

说实话,这道题的描述有点离谱,上一次见到这种随机生成数据的离谱题目还是ODT模板

考虑回头写一篇ODT的学习笔记

简而言之,就是求一个数据完全随机,但是强制在线的k级祖先

首先考虑重链剖分,这种算法是 \(O(n)\) 预处理,\(O(\log n)\) 回答的

再考虑一下倍增,显然是 \(O(n \log n)\) 预处理,\(O(\log n)\) 回答

如果将这两种算法结合呢?
首先树上倍增预处理出每个点的 \(2^n\) 级祖先,时间复杂度 \(O(n \log n)\)

然后对树进行长链剖分,注意,不是重链剖分

长链剖分的定义:

定义 长子节点 表示其子节点中子树最高的子结点。如果有多个子树最高的子结点,取其一。如果没有子节点,就无长子节点。

定义 短子节点 表示剩余的所有子结点。

从这个结点到长子节点的边为 长边

到其他短子节点的边为 短边

若干条首尾衔接的长边构成 长链

把落单的结点也当作长链,那么整棵树就被剖分成若干条长链。

其实就是把重链剖分的定义换了一下

但注意,从根到任意一个点的路径上最多有 \(\sqrt{n}\) 条长链

所以,对于重链剖分可以轻松解决的问题,用长链剖分反而会更慢,例如LCA

剖分后,再次预处理,对于长度为 \(d\) 的链的最高节点预处理其 \(1\)\(d\) 级祖先,时间复杂度为 \(O(n)\)

查询时找到一个值 \(r\),使得 \(2^r < k <2^{r+1}\)

并找到 \(2^r\) 次祖先所在的长链,如果 \(k\) 级祖先在长链内,那么直接 \(O(1)\) 在 DFS 序里查就好了

否则的话,由于 \(k \leq 2^r + d\),所以直接根据预处理来查就好了

后话

这道题直接重链剖分就很不错了,这种方法说实话实在是麻烦,纯属闲的

这种方法的理论复杂度还是很优的,毕竟是 \(O(n\log n)\) 预处理,\(O(1)\) 查询

但是其常数过大,对于这道题来说,可能甚至比重链剖分还慢

所以这里就不提供长链剖分的代码了,直接贴一个重链剖分求解的代码:

#include<iostream>
#include<cstdio>
#include<cstring>
using namespace std;
const int MAXN=5e5+7;
struct Edge
{
    int v,nxt;
} e[MAXN];
int dfn[MAXN],cnt,head[MAXN],tot;
int dep[MAXN],hs[MAXN],sz[MAXN],tp[MAXN],fa[MAXN],n,m,r,bd[MAXN];

inline void add(int u,int v)
{
    e[++tot].v=v;
    e[tot].nxt=head[u];
    head[u]=tot;
}

void dfs1(int u,int f)
{
    fa[u]=f;
    dep[u]=dep[f]+1;
    sz[u]=1;
    for(int i=head[u]; i; i=e[i].nxt)
    {
        int v=e[i].v;
        dfs1(v,u);
        sz[u]+=sz[v];
        if(sz[v]>sz[hs[u]]) hs[u]=v;
    }
}

void dfs2(int u,int tpf)
{
    tp[u]=tpf;
    dfn[u]=++cnt;
    bd[cnt]=u;
    if(!hs[u]) return;
    dfs2(hs[u],tpf);
    for(int i=head[u]; i; i=e[i].nxt)
    {
        int v=e[i].v;
        if(v==hs[u])continue;
        dfs2(v,v);
    }
}

#define ui unsigned int
ui s;

inline ui get(ui x)
{
    x ^= x << 13;
    x ^= x >> 17;
    x ^= x << 5;
    return s = x;
}

int kth(int u,int k)
{
    while(k>=dfn[u]-dfn[tp[u]]+1&&u!=r)
    {
        k-=(dfn[u]-dfn[tp[u]]+1);
        u=fa[tp[u]];
    }
    return bd[dfn[u]-k];
}

int main()
{
    scanf("%d%d%ud",&n,&m,&s);
    for(int i=1; i<=n; i++)
    {
        int a;
        scanf("%d",&a);
        if(a==0) r=i;
        else add(a,i);
    }
    dfs1(r,0);
    dfs2(r,r);
    int lst=0;
    long long ans=0;
    for(int i=1; i<=m; i++)
    {
        int x=(get(s)^lst)%n+1,k=(get(s)^lst)%dep[x];
        lst=kth(x,k);
        ans^=(long long)i*lst;
    }
    printf("%lld",ans);
    return 0;
}

posted @ 2022-03-31 21:24  yhang323  阅读(63)  评论(0编辑  收藏  举报