初识树链剖分
首发于摸鱼世界&更好的阅读体验
到现在也只会照着std打板子..
虽然这样,
毒树链剖分还是一个非常优雅的算法。
前置芝士:\(DFS\),线段树
树链剖分可以把树上的区间操作通过把树剖成一条条链,利用线段树等数据结构进行维护,从而达到\(O(nlogn)\)的优秀时间复杂度。
比如这样的操作:
在一棵树上,将\(x\)到\(y\)路径上点的点权加上\(w\),并要求支持查询两个点\(x,y\)路径间的点权和。
乍一看,两个操作都很简单。修改操作可以用树上差分\(O(1)\)乱搞,静态查询可以用\(LCA\)完成。
但是合起来就没有办法了:每次查询之前都需要\(O(n)\)预处理,数据略大直接\(T\)飞。
于是树剖出场了。
区间修改&查询是线段树的强项,但是它只能对一段连续的区间进行查询。于是我们需要想办法让树上需要操作的路径变成一段连续的区间。
引入一个概念:重儿子,也就是一个节点的儿子中\(size\)最大的。连接到重儿子的边即为重边
重儿子组成的链,就是重链。
比如在这棵树中,连续的红边组成的就是一条条重链。我们用\(top[u]\)记录节点\(u\)所在重链的顶端。特别地,没有被重边连接的节点,\(top[u]=u\),即它们所在重链的顶端就是自身。注意到,当\(u\)是一条重链的顶端(\(top[u]=u\))时,它的父节点一定在另一条重链上。
始终记住我们的目标:把在树上区间操作转化为在一段连续的区间进行操作。
考虑如何用\(DFS\)给树上的每个节点在区间内找到一个合适的位置。我们发现,从根节点出发,优先走重边,这样的\(dfs\)序似乎有点特殊。
例如上图,优先走重边的\(dfs\)序为:\(124798356\)。很显然,这样的\(dfs\)序满足同一条重链上的点\(dfs\)序连续。所以用线段树维护的,就是重链上的信息。
这样操作之后,我们可以做到的是:\(O(logn)\)对一条重链上的信息区间修改,区间查询。
对于两个节点\(u,v\),我们可以通过不断地跳重链,直到两个节点在同一条重链上。这个是很好实现的,因为只需要跳到\(fa[top[u]]\),就到了一条新的重链。
代码实现仅树剖部分是不麻烦的。我们需要维护的信息有\(dep\)(节点深度),\(fa\)(父节点),\(son\)(重儿子),\(sz\)(子树节点数,用来判重儿子),这些可以用一次\(dfs\)完成。
void dfs1(int u,int f,int d)//fa,dep,son,sz
{
fa[u]=f;
dep[u]=d;
sz[u]=1;
for(int i=head[u];i;i=nxt[i])
{
int v=to[i];
if(v!=f)
{
dfs1(v,u,d+1);
sz[u]+=sz[v];
if(sz[v]>sz[son[u]])son[u]=v;
}
}
}
接下来,就需要把这棵树每个节点压到线段树维护的序列的一个位置了。就像上文说的一样,按照优先重边的\(dfs\)序压入线段树即可。于是记录一个\(id[i]\)表示原树中节点\(i\)对应的线段树中的下标。\(rk[i]\)反过来记录线段树中下标为\(i\)的原数编号。
由于预处理了父节点,所以\(dfs2\)传参只需要\(u\)(当前节点)和\(t\)(当前重链顶端节点)。在遍历儿子之前先\(dfs2(son[u],t)\),因为\(u\)和\(u\)的重儿子在同一条重链上。接下来才遍历轻(非重)儿子\(v\),但是传参为\(dfs2(v,v)\),因为\(v\)就是新的一条重链的起点。
void dfs2(int u,int t)//top,id,rk
{
top[u]=t;
id[u]=++tot;
rk[tot]=u;
if(!son[u])return;
dfs2(son[u],t);
for(int i=head[u];i;i=nxt[i])
{
int v=to[i];
if(v!=fa[u]&&v!=son[u])
dfs2(v,v);
}
}
再回到最开始的问题:
在一棵树上,将\(x\)到\(y\)路径上点的点权加上\(w\),并要求支持查询两个点\(x,y\)路径间的点权和。
答案就显得很明了了。
如果是查询,先保证\(dep[x]>dep[y]\),然后就和\(LCA\)类似的,利用重链加速:每次把\([top[x],x]\)这条重链的和累加到答案上,再使\(x\)跳到另一条重链上,即\(x=fa[top[x]]\),直到\(x,y\)在同一条重链上,再把两个点之间的信息统计累加一下即可。
int getsum(int x,int y)
{
int res=0;
while(top[x]!=top[y])
{
if(dep[top[x]]<dep[top[y]])swap(x,y);
sum=0;
asksum(1,id[top[x]],id[x]);
(res+=sum)%=mod;
x=fa[top[x]];
}
if(id[x]>id[y])swap(x,y);
sum=0;
asksum(1,id[x],id[y]);
(res+=sum)%=mod;
return res;
}
修改同理。
于是我们发现,虽然我们采用了优先重边的\(dfs\)序,但它毕竟遍历的都是自己的儿子节点。所以...还可以支持子树操作。因为一棵子树在重边优先的\(dfs\)序中编号也是连续的。并且这个编号很容易算,因为我们维护了一个\(sz\)信息。所以树中\(x\)节点的子树对应的就是线段树维护的\([id[x],id[x]+sz[x]-1]\)这个区间。
于是还是板子一般的线段树区间修改&查询。
可以注意到线段树部分基本没讲,因为每个人写线段树的方法可能不太一样,蒟蒻我分享的只是树剖的思想。
另外,为什么树剖每次操作是\(O(logn)\)呢?利用线段树的子树操作自然是\(O(logn)\),剩下的就是那个像\(LCA\)一样的跳重链。
证明:从任意节点向根节点跳重链,经过的重链和轻边(非重边)都是\(log\)级别的。
考虑到每走一条轻边,子树大小至少翻倍,否则这就不是条轻边了。于是经过的轻边就最多为\(log_2 n\)条。而重链和轻边的交替出现的,所以数量也在这个级别。
于是每次操作就只有\(O(logn)\)的时间复杂度。
以下是代码
#include<bits/stdc++.h>
#define int long long
#define ls (k<<1)
#define rs (k<<1|1)
using namespace std;
const int N=1e5+10;
struct node
{
int l,r,w,f;
}t[N<<2];
int a[N];
int n,m,r,mod;
int sum;
int head[N<<1],to[N<<1],nxt[N<<1],cnt;
int sz[N],fa[N],dep[N],son[N];
int top[N],id[N],rk[N],tot;
inline int read()
{
int x=0,f=1;
char ch=getchar();
while(ch<'0'||ch>'9'){if(ch=='-')f=-1;ch=getchar();}
while(ch>='0'&&ch<='9'){x=(x<<1)+(x<<3)+(ch^48);ch=getchar();}
return x*f;
}
void add(int u,int v)
{
cnt++;
to[cnt]=v;
nxt[cnt]=head[u];
head[u]=cnt;
}
void dfs1(int u,int f)
{
fa[u]=f;
sz[u]=1;
dep[u]=dep[f]+1;
for(int i=head[u];i;i=nxt[i])
{
int v=to[i];
if(v==f)continue;
dfs1(v,u);
sz[u]+=sz[v];
if(sz[v]>sz[son[u]])son[u]=v;
}
return;
}
void dfs2(int u,int t)
{
top[u]=t;
id[u]=++tot;
rk[tot]=u;
if(!son[u])return;
dfs2(son[u],t);
for(int i=head[u];i;i=nxt[i])
{
int v=to[i];
if(v!=fa[u]&&v!=son[u])dfs2(v,v);//新的重链
}
}
void build(int k,int l,int r)
{
t[k].l=l,t[k].r=r;
if(l==r)
{
t[k].w=a[rk[l]];
return;
}
int m=l+r>>1;
build(ls,l,m);
build(rs,m+1,r);
t[k].w=t[ls].w+t[rs].w;
return;
}
void down(int k)
{
t[ls].w+=(t[ls].r-t[ls].l+1)*t[k].f;
t[rs].w+=(t[rs].r-t[rs].l+1)*t[k].f;
t[ls].f+=t[k].f;
t[rs].f+=t[k].f;
t[k].f=0;
}
void addsum(int k,int x,int y,int p)
{
int l=t[k].l,r=t[k].r;
if(x<=l&&r<=y)
{
t[k].w+=(r-l+1)*p;
t[k].f+=p;
return;
}
down(k);
int m=l+r>>1;
if(x<=m)addsum(ls,x,y,p);
if(y>m)addsum(rs,x,y,p);
t[k].w=t[ls].w+t[rs].w;
return;
}
void asksum(int k,int x,int y)
{
int l=t[k].l,r=t[k].r;
if(x<=l&&r<=y)
{
sum+=t[k].w;
return;
}
down(k);
int m=l+r>>1;
if(x<=m)asksum(ls,x,y);
if(y>m)asksum(rs,x,y);
t[k].w=t[ls].w+t[rs].w;
return;
}
//-----------------------------
int getsum(int x,int y)
{
int res=0;
while(top[x]!=top[y])
{
if(dep[top[x]]<dep[top[y]])swap(x,y);
sum=0;
asksum(1,id[top[x]],id[x]);
(res+=sum)%=mod;
x=fa[top[x]];
}
if(id[x]>id[y])swap(x,y);
sum=0;
asksum(1,id[x],id[y]);
(res+=sum)%=mod;
return res;
}
void update(int x,int y,int p)
{
while(top[x]!=top[y])
{
if(dep[top[x]]<dep[top[y]])swap(x,y);
addsum(1,id[top[x]],id[x],p);
x=fa[top[x]];
}
if(id[x]>id[y])swap(x,y);
addsum(1,id[x],id[y],p);
return;
}
signed main()
{
n=read(),m=read(),r=read(),mod=read();
for(int i=1;i<=n;i++)a[i]=read();
for(int i=1;i<n;i++)
{
int x=read(),y=read();
add(x,y),add(y,x);
}
dfs1(r,0);
dfs2(r,r);
build(1,1,n);
for(int i=1;i<=m;i++)
{
int x,y,z;
int opt=read();
if(opt==1)
{
x=read(),y=read(),z=read();
update(x,y,z);
}
if(opt==2)
{
x=read(),y=read();
printf("%lld\n",getsum(x,y)%mod);
}
if(opt==3)
{
x=read(),z=read();
addsum(1,id[x],id[x]+sz[x]-1,z);
}
if(opt==4)
{
x=read();
sum=0;asksum(1,id[x],id[x]+sz[x]-1);
printf("%lld\n",sum%mod);
}
}
return 0;
}
代码的确是长,也不算容易调,但是真正妙的是利用轻重链的思想进行的化树为链。
感谢阅读。