洛谷 P3384 【模板】轻重链剖分

树链剖分模板。

给一个 nn 结点的树(连通且无环),每个结点上有一个数值,进行如下操作:

  1. 结点 xx 到结点 yy 的最短路径上所有结点的值 +z+z
  2. 求结点 xx 到结点 yy 的最短路径上的所有结点值的和
  3. 将以结点 xx 为根结点的子树中所有结点值 +z+z
  4. 求以结点 xx 为根结点的子树内所有结点值之和

树链剖分详解(洛谷模板 P3384)

树链剖分oiwiki

上面两个链接把树链剖分都讲得很好,我在这里做一下复读机…

树链剖分的目的是把一棵树分成许多链,减少一些问题的处理难度,通常意义上的树链剖分指的是重链剖分(zhong 轻重的重),对应英文为 heavy path decomposition/heavy-light decomposition

其中一个结点的重儿子指的是该儿子形成的子树是所有儿子总最大的(多个最大则随便取),剩下的叫轻儿子。一个结点连接它重儿子的边叫做重边,其他的叫做轻边。从一个轻儿子开始,不断地重复连接他的重儿子,形成重链,若叶节点为轻儿子,则它自己形成一条重链,示意图如下(取自 oiwiki):

在这里插入图片描述
用两个 dfs 来得到一些需要的信息。

第一个 dfs 得到每个结点的深度 dep[u],父结点编号 fa[u] ,该结点形成的子树大小 sz[u],该结点的重儿子编号 son[u] ,则代码如下:

void dfs1(int u,int f,int d){
	dep[u]=d;fa[u]=f;sz[u]=1;
	int msz=-1;
	for(int i=head[u];i!=-1;i=edge[i].nxt){
		int v=edge[i].to;if(v==f)continue;
		dfs1(v,u,d+1);sz[u]+=sz[v];
		if(sz[v]>msz)msz=sz[v],son[u]=v;
	}
}

第二个 dfs 要进行重边优先遍历,并记录对应的 dfs 序 dfn,以及 dfn 编号所对应的结点 rktop 记录结点的链顶。

void dfs2(int u,int t){
	dfn[u]=++tot;rk[tot]=u;top[u]=t;
	if(!son[u])return;
	dfs2(son[u],t);
	for(int i=head[u];i!=-1;i=edge[i].nxt){
		int v=edge[i].to;
		if(v==son[u]||v==fa[u])continue;
		dfs2(v,v);
	}
}

其中重边优先遍历之后会发现生成的 dfs 序 dfn 有以下两个特性:

  1. 每条重链内的 dfn 是连续的
  2. 每个子树内的 dfn 也是连续的

处理两点 u,v 间路径和时,假设 dep[top[u]]>dep[top[v]],也就是说 u 的链顶更深,首先 res 加上 utop[u] 的点权和,然后 u 移动到 fa[top[u]] ,重复执行操作,直到 u,v 在一个链上,此时再加上两点的区间和,更新时类似。

int qRange(int u,int v){
	res=0;
	while(top[u]!=top[v]){
		if(dep[top[u]]<dep[top[v]])swap(u,v);
		res+=query(dfn[top[u]],dfn[u],1,n,1);
		res%=MOD;u=fa[top[u]];
	}
	if(dep[u]>dep[v])swap(u,v);
	res+=query(dfn[u],dfn[v],1,n,1);
	return res%MOD;
}
void updRange(int u,int v,int k){
	k%=MOD;
	while(top[u]!=top[v]){
		if(dep[top[u]]<dep[top[v]])swap(u,v);
		update(dfn[top[u]],dfn[u],1,n,1,k);
		u=fa[top[u]];
	}
	if(dep[u]>dep[v])swap(u,v);
	update(dfn[u],dfn[v],1,n,1,k);
}

处理一点的子树点权和,之前已经记录了每个点的子树大小,又因为子树内编号连续,因此可以直接操作:

int qSon(int u){
	return query(dfn[u],dfn[u]+sz[u]-1,1,n,1);
}
void updSon(int u,int k){
	update(dfn[u],dfn[u]+sz[u]-1,1,n,1,k);
}

query,update 用线段树来实现。

附:假如求 LCA (最近公共祖先)的话,从 top 较深的点开始不断向上跳,直到两个点在一个链上,深度较小的即为 LCA

int lca(int u,int v){
	while(top[u]!=top[v]){
		if(dep[top[u]]<dep[top[v]])swap(u,v);
		u=fa[top[u]];
	}
	return dep[u]>dep[v]?v:u;
}

整体代码如下:

#include<iostream>
#include<cstdio>
#include<cstring>
#define MAXN 100010
#define MAXM 200010
#define lc(p) (p<<1)
#define rc(p) (p<<1|1)
using namespace std;
typedef long long ll;
ll res[MAXN<<2],tag[MAXN<<2],k,z;
int head[MAXN],n,m,r,M,a[MAXN],u,v,dep[MAXN],fa[MAXN],sz[MAXN];
int son[MAXN],dfn[MAXN],rk[MAXN],top[MAXN],t,x,y,tot,tt;
struct Edge{
    int to,nxt;
}edge[MAXM];
void addedge(int u,int v){
    edge[tot].to=v;edge[tot].nxt=head[u];
    head[u]=tot++;
}
void dfs1(int u,int f,int d){
    dep[u]=d;fa[u]=f;sz[u]=1;int msz=-1;
    for(int i=head[u];i!=-1;i=edge[i].nxt){
        int v=edge[i].to;if(v==f)continue;
        dfs1(v,u,d+1);sz[u]+=sz[v];
        if(sz[v]>msz)msz=sz[v],son[u]=v;
    }
}
void dfs2(int u,int t){
    dfn[u]=++tt;rk[tt]=u;top[u]=t;
    if(!son[u])return;dfs2(son[u],t);
    for(int i=head[u];i!=-1;i=edge[i].nxt){
        int v=edge[i].to;
        if(v==fa[u]||v==son[u])continue;
        dfs2(v,v);
    }
}
void pushup(int p){res[p]=(res[lc(p)]+res[rc(p)])%M;}
void f(int l,int r,int p,ll k){
    res[p]+=k*(r-l+1);tag[p]+=k;
}
void pushdown(int l,int r,int p){
    int mid=l+(r-l)/2;
    f(l,mid,lc(p),tag[p]);f(mid+1,r,rc(p),tag[p]);
    tag[p]=0;
}
void build(int l,int r,int p){
    if(l==r){
        res[p]=a[rk[l]]%M;
        return;
    }
    int mid=l+(r-l)/2;
    build(l,mid,lc(p));build(mid+1,r,rc(p));
    pushup(p);
}
void update(int x,int y,int l,int r,int p,ll k){
    if(x<=l&&r<=y){
        res[p]+=k*(r-l+1);res[p]%=M;
        tag[p]+=k;
        return;
    }
    pushdown(l,r,p);
    int mid=l+(r-l)/2;
    if(x<=mid)update(x,y,l,mid,lc(p),k);
    if(y>mid)update(x,y,mid+1,r,rc(p),k);
    pushup(p);
}
ll query(int x,int y,int l,int r,int p){
    if(x<=l&&r<=y){
        return res[p];
    }
    pushdown(l,r,p);
    int mid=l+(r-l)/2;ll res=0;
    if(x<=mid)res+=query(x,y,l,mid,lc(p));
    if(y>mid)res+=query(x,y,mid+1,r,rc(p));
    return res;
}
int qRange(int u,int v){
    int res=0;
    while(top[u]!=top[v]){
        if(dep[top[u]]<dep[top[v]])swap(u,v);
        res+=query(dfn[top[u]],dfn[u],1,tt,1);
        res%=M;u=fa[top[u]];
    }
    if(dep[u]>dep[v])swap(u,v);
    res+=query(dfn[u],dfn[v],1,tt,1);
    return res%M;
}
void updRange(int u,int v,ll k){
    while(top[u]!=top[v]){
        if(dep[top[u]]<dep[top[v]])swap(u,v);
        update(dfn[top[u]],dfn[u],1,tt,1,k);
        u=fa[top[u]];
    }
    if(dep[u]>dep[v])swap(u,v);
    update(dfn[u],dfn[v],1,tt,1,k);
}
int qSon(int u){
    return query(dfn[u],dfn[u]+sz[u]-1,1,tt,1)%M;
}
void updSon(int u,ll k){
    update(dfn[u],dfn[u]+sz[u]-1,1,tt,1,k);
}
int main(){
#ifdef WINE
    freopen("data.in","r",stdin);
#endif
    memset(head,-1,sizeof(head));tot=0;tt=0;
    scanf("%d%d%d%d",&n,&m,&r,&M);
    for(int i=1;i<=n;i++)scanf("%d",&a[i]);
    for(int i=1;i<n;i++){
        scanf("%d%d",&u,&v);
        addedge(u,v);addedge(v,u);
    }
    dfs1(r,0,1);
    dfs2(r,r);
    build(1,tt,1);
    while(m--){
        scanf("%d%d",&t,&x);
        if(t==1){
            scanf("%d%lld",&y,&z);
            updRange(x,y,z);
        }else if(t==2){
            scanf("%d",&y);
            printf("%d\n",qRange(x,y));
        }else if(t==3){
            scanf("%lld",&z);
            updSon(x,z);
        }else printf("%d\n",qSon(x));
    }
    return 0;
}

在这里插入图片描述

posted @ 2020-04-11 11:52  winechord  阅读(159)  评论(0编辑  收藏  举报