树链剖分是个很简单的算法
树链剖分一共分为两种,一种是重链剖分,比较常见;还有一种是长链剖分,比较少见
一.重链剖分
重儿子:对于每一个非叶子节点,它的儿子中 以那个儿子为根的子树节点数最大的儿子 为该节点的重儿子 (Ps: 感谢@shzr大佬指出我此句话的表达不严谨qwq, 已修改)
轻儿子:对于每一个非叶子节点,它的儿子中 非重儿子 的剩下所有儿子即为轻儿子
叶子节点没有重儿子也没有轻儿子(因为它没有儿子。。)
重边:一个父亲连接他的重儿子的边称为重边 //原写法:连接任意两个重儿子的边叫做重边
轻边:剩下的即为轻边
重链:相邻重边连起来的 连接一条重儿子 的链叫重链
对于叶子节点,若其为轻儿子,则有一条以自己为起点的长度为1的链
每一条重链以轻儿子为起点
这图好像是洛咕上的,我还是懒得自己画
说起来这些概念实际很简单
但写起来还是要有较强码力的
我们先要写把轻重链求出的函数
一共需要写两个函数
1.dfs1
dfs1主要求出:
1.该节点的子树大小(1+所有子节点子树大小之和)
2.重儿子(找到所有子节点中子树大小最大的)
3.父节点
4.深度
dfs1还是比较简单的qaq
inline void dfs1(register int x)
{
size[x]=1;
for(register int i=head[x];i;i=e[i].next)
if(e[i].to!=fa[x])
{
dep[e[i].to]=dep[x]+1;
fa[e[i].to]=x;
dfs1(e[i].to);
size[x]+=size[e[i].to];
if(size[e[i].to]>size[son[x]])
son[x]=e[i].to;
}
}
dfs2
dfs2是重链剖分的重点
dfs2要求出:
1.树的dfs序(优先搜重儿子)
2.在树的dfs序之下,珂以把树上的值存到连续的数列中,到时就珂以线段树维护
3.每个重链的顶端,方便到时候跳链(不懂的话后面会讲)
inline void dfs2(register int x,register int t)
{
dl[x]=++tot;
a[tot]=ch[x];
top[x]=t;
if(son[x])
dfs2(son[x],t);
for(register int i=head[x];i;i=e[i].next)
if(e[i].to!=fa[x]&&e[i].to!=son[x])
dfs2(e[i].to,e[i].to);
}
跑完两个dfs之后就珂以用线段树
build建树:
inline void pushup(register int x)
{
sum[x]=sum[x<<1]+sum[x<<1|1];
sum[x]%=mod;
}
inline void build(register int x,register int l,register int r)
{
if(l==r)
{
sum[x]=a[l];
tag[x]=0;
return;
}
int mid=l+r>>1;
build(x<<1,l,mid);
build(x<<1|1,mid+1,r);
pushup(x);
}
下面是处理查询
操作1:把x节点到y节点路径上的值加z
这里需要一个跳链的函数——cal1
inline void pushdown(register int x,register int l,register int r)
{
int ls=x<<1,rs=x<<1|1,mid=l+r>>1;
sum[ls]+=(mid-l+1)*tag[x];
sum[rs]+=(r-mid)*tag[x];
tag[ls]+=tag[x];
tag[rs]+=tag[x];
sum[ls]%=mod;
sum[rs]%=mod;
tag[ls]%=mod;
tag[rs]%=mod;
tag[x]=0;
}
inline void update(register int x,register int l,register int r,register int L,register int R,register int k)
{
if(L<=l&&r<=R)
{
sum[x]+=(r-l+1)*k;
tag[x]+=k;
sum[x]%=mod;
tag[x]%=mod;
return;
}
if(tag[x])
pushdown(x,l,r);
int mid=l+r>>1;
if(L<=mid)
update(x<<1,l,mid,L,R,k);
if(R>=mid+1)
update(x<<1|1,mid+1,r,L,R,k);
pushup(x);
}
inline void cal1(register int x,register int y,register int z)
{
int fx=top[x],fy=top[y];
while(fx!=fy)
{
if(dep[fx]<dep[fy])
{
swap(x,y);
swap(fx,fy);
}
update(1,1,tot,dl[fx],dl[x],z);
x=fa[fx];
fx=top[x];
}
if(dl[x]>dl[y])
swap(x,y);
update(1,1,tot,dl[x],dl[y],z);
}
操作2:查询x到y路径点权之和
和操作1差不多,需要跳链
inline ll query(register int x,register int l,register int r,register int L,register int R)
{
if(L<=l&&r<=R)
return sum[x];
if(tag[x])
pushdown(x,l,r);
ll res=0;
int mid=l+r>>1;
if(L<=mid)
res+=query(x<<1,l,mid,L,R)%mod;
if(R>=mid+1)
res+=query(x<<1|1,mid+1,r,L,R)%mod;
return res%mod;
}
inline ll cal2(register int x,register int y)
{
ll res=0;
int fx=top[x],fy=top[y];
while(fx!=fy)
{
if(dep[fx]<dep[fy])
{
swap(x,y);
swap(fx,fy);
}
res=(res%mod+query(1,1,tot,dl[fx],dl[x])%mod)%mod;
x=fa[fx];
fx=top[x];
}
if(dl[x]>dl[y])
swap(x,y);
res=(res%mod+query(1,1,tot,dl[x],dl[y])%mod)%mod;
return res%mod;
}
操作3:把x的子树内所有节点全值加z
考虑到子树内dfs序是相连的
所以被修改区间是一个连续的区间,所以直接上线段树
update(1,1,tot,dl[x],dl[x]+size[x]-1,z%mod);
操作四:求x的子树内所有节点的和
和操作3一样,珂以直接用线段树
write(query(1,1,tot,dl[x],dl[x]+size[x]-1)%mod);
最后上一下重链剖分整体代码
#include <bits/stdc++.h>
#define ll long long
#define N 100005
using namespace std;
inline ll read()
{
register ll x=0,f=1;register char ch=getchar();
while(ch<'0'||ch>'9'){if(ch=='-')f=-1;ch=getchar();}
while(ch>='0'&&ch<='9')x=(x<<3)+(x<<1)+ch-'0',ch=getchar();
return x*f;
}
inline void write(register ll x)
{
if(!x)putchar('0');if(x<0)x=-x,putchar('-');
static int sta[36];int tot=0;
while(x)sta[tot++]=x%10,x/=10;
while(tot)putchar(sta[--tot]+48);
}
struct node{
int to,next;
}e[N<<1];
int head[N],cnt=0;
inline void add(register int u,register int v)
{
e[++cnt]=(node){v,head[u]};
head[u]=cnt;
}
ll ch[N];
ll n,m,rt,mod;
ll size[N],dep[N],fa[N],son[N];
ll tot=0,dl[N],a[N],top[N];
inline void dfs1(register int x)
{
size[x]=1;
for(register int i=head[x];i;i=e[i].next)
if(e[i].to!=fa[x])
{
dep[e[i].to]=dep[x]+1;
fa[e[i].to]=x;
dfs1(e[i].to);
size[x]+=size[e[i].to];
if(size[e[i].to]>size[son[x]])
son[x]=e[i].to;
}
}
inline void dfs2(register int x,register int t)
{
dl[x]=++tot;
a[tot]=ch[x];
top[x]=t;
if(son[x])
dfs2(son[x],t);
for(register int i=head[x];i;i=e[i].next)
if(e[i].to!=fa[x]&&e[i].to!=son[x])
dfs2(e[i].to,e[i].to);
}
ll sum[N<<3],tag[N<<3];
inline void pushup(register int x)
{
sum[x]=sum[x<<1]+sum[x<<1|1];
sum[x]%=mod;
}
inline void build(register int x,register int l,register int r)
{
if(l==r)
{
sum[x]=a[l];
tag[x]=0;
return;
}
int mid=l+r>>1;
build(x<<1,l,mid);
build(x<<1|1,mid+1,r);
pushup(x);
}
inline void pushdown(register int x,register int l,register int r)
{
int ls=x<<1,rs=x<<1|1,mid=l+r>>1;
sum[ls]+=(mid-l+1)*tag[x];
sum[rs]+=(r-mid)*tag[x];
tag[ls]+=tag[x];
tag[rs]+=tag[x];
sum[ls]%=mod;
sum[rs]%=mod;
tag[ls]%=mod;
tag[rs]%=mod;
tag[x]=0;
}
inline void update(register int x,register int l,register int r,register int L,register int R,register int k)
{
if(L<=l&&r<=R)
{
sum[x]+=(r-l+1)*k;
tag[x]+=k;
sum[x]%=mod;
tag[x]%=mod;
return;
}
if(tag[x])
pushdown(x,l,r);
int mid=l+r>>1;
if(L<=mid)
update(x<<1,l,mid,L,R,k);
if(R>=mid+1)
update(x<<1|1,mid+1,r,L,R,k);
pushup(x);
}
inline ll query(register int x,register int l,register int r,register int L,register int R)
{
if(L<=l&&r<=R)
return sum[x];
if(tag[x])
pushdown(x,l,r);
ll res=0;
int mid=l+r>>1;
if(L<=mid)
res+=query(x<<1,l,mid,L,R)%mod;
if(R>=mid+1)
res+=query(x<<1|1,mid+1,r,L,R)%mod;
return res%mod;
}
inline void cal1(register int x,register int y,register int z)
{
int fx=top[x],fy=top[y];
while(fx!=fy)
{
if(dep[fx]<dep[fy])
{
swap(x,y);
swap(fx,fy);
}
update(1,1,tot,dl[fx],dl[x],z);
x=fa[fx];
fx=top[x];
}
if(dl[x]>dl[y])
swap(x,y);
update(1,1,tot,dl[x],dl[y],z);
}
inline ll cal2(register int x,register int y)
{
ll res=0;
int fx=top[x],fy=top[y];
while(fx!=fy)
{
if(dep[fx]<dep[fy])
{
swap(x,y);
swap(fx,fy);
}
res=(res%mod+query(1,1,tot,dl[fx],dl[x])%mod)%mod;
x=fa[fx];
fx=top[x];
}
if(dl[x]>dl[y])
swap(x,y);
res=(res%mod+query(1,1,tot,dl[x],dl[y])%mod)%mod;
return res%mod;
}
int main()
{
n=read(),m=read(),rt=read(),mod=read();
for(register int i=1;i<=n;++i)
ch[i]=read(),ch[i]%=mod;
for(register int i=1;i<n;++i)
{
int u=read(),v=read();
add(u,v),add(v,u);
}
dep[rt]=1;
fa[rt]=rt;
dfs1(rt);
dfs2(rt,rt);
build(1,1,n);
while(m--)
{
int opt=read();
if(opt==1)
{
int x=read(),y=read(),z=read();
cal1(x,y,z%mod);
}
else if(opt==2)
{
int x=read(),y=read();
write(cal2(x,y)%mod);
printf("\n");
}
else if(opt==3)
{
int x=read(),z=read();
update(1,1,tot,dl[x],dl[x]+size[x]-1,z%mod);
}
else
{
int x=read();
write(query(1,1,tot,dl[x],dl[x]+size[x]-1)%mod);
printf("\n");
}
}
return 0;
}
相关题目
长链剖分
咕咕咕