树链剖分

定义

把树剖成一条条不相交的链,对树的操作就转化成了对链的操作

概念

重儿子:对于每一个非叶子节点,它的儿子中 以那个儿子为根的子树节点数最大的儿子 为该节点的重儿子
轻儿子:对于每一个非叶子节点,它的儿子中 非重儿子 的剩下所有儿子即为轻儿子

重边:连接任意两个重儿子的边叫做重边
轻边:剩下的即为轻边

重链:相邻重边连起来的 连接一条重儿子 的链叫重链

对于叶子节点,若其为轻儿子,则有一条以自己为起点的长度为1的链
每一条重链以轻儿子为起点

链头:一条重链上深度最小的点

剖树

dfs1

1.记录深度
2.记录父亲
3.记录子树大小,包括它自己
4.记录重儿子

void dfs1(int x,int father){
	deep[x]=deep[father]+1;//记录深度
	fa[x]=father;//记录父亲
	siz[x]=1;//记录子树大小,包括它自己
	for(int i=head[x];~i;i=e[i].nxt){
		int y=e[i].to;
		if(y!=father){
			dfs1(y,x);
			siz[x]+=siz[y];
			if(!son[x]||siz[son[x]]<siz[y]){//记录重儿子
				son[x]=y;
			}
		}
	}
}

复杂度:O(n)

dfs2

1.记录新编号
2.记录链头

void dfs2(int x,int topx){
        id[x]=++num;//记录新编号
	top[x]=topx;//记录链头
	if(!son[x])return;//叶子节点
	dfs2(son[x],topx);//先dfs重儿子
	for(int i=head[x];~i;i=e[i].nxt){
		int y=e[i].to;
		if(y!=son[x]&&y!=fa[x]){
			dfs2(y,y);//后dfs轻儿子
		}
	}
}

复杂度:O(n)

LCA

上一篇LCA的博客里说要讲树剖版LCA,概念不在此赘述

剖分后如何求LCA?

分两种情况:

  1. x,y 在同一条重链上,因为一条重链上的点都是祖先和后代的关系,于是深度较浅的点即为LCA

  2. x,y 不在同一条重链上,让 x,y 沿着重链往上跳,直到跳到同一条重链上

int lca(int x,int y){
	while(top[x]!=top[y]){
		if(deep[top[x]]<deep[top[y]])swap(x,y);
		x=fa[top[x]];
	}
	return deep[x]<deep[y]?x:y;
}

复杂度:O(logn) ,查询 m 次 总复杂度:O(mlogn)

#include <bits/stdc++.h>
using namespace std;
const int N=500005;
int n,m,s;
int head[N*2];
struct node{
	int to,nxt;
}e[N*2];
int cnt;
void add(int u,int v){
	e[cnt].to=v;
	e[cnt].nxt=head[u];
	head[u]=cnt++;
}
int deep[N],siz[N],fa[N],son[N],top[N];
void dfs1(int x,int father){
	deep[x]=deep[father]+1;
	fa[x]=father;
	siz[x]=1;
	for(int i=head[x];~i;i=e[i].nxt){
		int y=e[i].to;
		if(y!=father){
			dfs1(y,x);
			siz[x]+=siz[y];
			if(!son[x]||siz[son[x]]<siz[y]){
				son[x]=y;
			}
		}
	}
}
void dfs2(int x,int topx){
	top[x]=topx;
	if(!son[x])return;
	dfs2(son[x],topx);
	for(int i=head[x];~i;i=e[i].nxt){
		int y=e[i].to;
		if(y!=son[x]&&y!=fa[x]){
			dfs2(y,y);
		}
	}
}
int lca(int x,int y){
	while(top[x]!=top[y]){
		if(deep[top[x]]<deep[top[y]])swap(x,y);
		x=fa[top[x]];
	}
	return deep[x]<deep[y]?x:y;
}
int main(){
	for(int i=0;i<N*2;i++){
		head[i]=-1;
		e[i].nxt=-1;
	}
	cin>>n>>m>>s;
	for(int i=1;i<n;i++){
		int a,b;
		cin>>a>>b;
		add(a,b),add(b,a);
	}
	dfs1(s,0);
	dfs2(s,s);
	for(int i=1;i<=m;i++){
		int a,b;
		cin>>a>>b;
		cout<<lca(a,b)<<endl;
	}
	return 0;
}

DFS序

大家可能会发现,LCA中并没有用到 id[] ,即dfs序,这就涉及到重链的一个重要的性质

一条重链内的dfs序是连续的,一个子树内的dfs序是连续的

也就是说,如果用dfs序标记重链的节点,这条重链就变成了连续的数字,那么在重链上的区间问题便可以用线段树来维护

应用

洛谷P3384

  1. 将树从 x 到 y 结点最短路径上所有节点的值都加上 z 。

  2. 求树从 x 到 y 结点最短路径上所有节点的值之和。

  3. 将以 x 为根节点的子树内所有节点值都加上 z 。

  4. 求以 x 为根节点的子树内所有节点值之和 。

操作一

  1. 将树从 x 到 y 结点最短路径上所有节点的值都加上 z 。

     x 到 y 的最短路径经过 lca(x,y) ,这个过程实际上就是查找lca 
    
void updatel(ll x,ll y,ll z){//过程类lca
	while(top[x]!=top[y]){
		if(deep[top[x]]<deep[top[y]])swap(x,y);
		update(1,id[top[x]],id[x],1,n,z);//修改一条重链内部
		x=fa[top[x]];//跳过一条轻边,到达上一条重链
	}
	if(deep[x]>deep[y])swap(x,y);
	update(1,id[x],id[y],1,n,z);//当 x,y 到达同一条重链上,修改 x,y 之间的部分
}

操作二

  1. 求树从 x 到 y 结点最短路径上所有节点的值之和。

     把操作一的修改变成查询就好了
    
ll queryl(ll x,ll y){
	ll an=0;
	while(top[x]!=top[y]){
		if(deep[top[x]]<deep[top[y]])swap(x,y);
		an+=query(1,id[top[x]],id[x],1,n);
		an%=mod;
		x=fa[top[x]];
	}
	if(deep[x]>deep[y])swap(x,y);
	an+=query(1,id[x],id[y],1,n);
	return an%mod;
}

操作三

  1. 将以 x 为根节点的子树内所有节点值都加上 z 。

      一个子树内部dfs序连续
    
void updatet(ll x,ll z){
	update(1,id[x],id[x]+siz[x]-1,1,n,z);
}

操作四

  1. 求以 x 为根节点的子树内所有节点值之和 。

     把操作三的修改变成查询就好了
    
ll queryt(ll x){
	return query(1,id[x],id[x]+siz[x]-1,1,n)%mod;
}

总代码

总代码就是线段树和树链部分的简单堆砌

#include <bits/stdc++.h>
using namespace std;
#define ll long long
ll ls(int x){return x<<1;}
ll rs(int x){return x<<1|1;}
const int N=100005;
ll n,m,r,mod;
ll a[N];
struct node{
	ll to,nxt;
}e[N<<1];
ll head[N<<1];
ll cnt;
void add(ll u,ll v){
	e[cnt].to=v;
	e[cnt].nxt=head[u];
	head[u]=cnt++;
}
ll deep[N],fa[N],son[N],siz[N];
void dfs1(ll x,ll father){
	deep[x]=deep[father]+1;
	fa[x]=father;
	siz[x]=1;
	for(int i=head[x];~i;i=e[i].nxt){
		ll y=e[i].to;
		if(y!=father){
			dfs1(y,x);
			siz[x]+=siz[y];
			if(!son[x]||siz[y]>siz[son[x]]){
				son[x]=y;
			}
		}
	}
}
ll top[N],id[N],ans,w[N];
void dfs2(ll x,ll topx){
	top[x]=topx;
	id[x]=++ans;
	w[ans]=a[x];
	if(!son[x])return;
	dfs2(son[x],topx);
	for(int i=head[x];~i;i=e[i].nxt){
		ll y=e[i].to;
		if(y!=fa[x]&&y!=son[x]){
			dfs2(y,y);
		}
	}
}
ll tree[N<<2],tag[N<<2];
void push_up(ll p){
	tree[p]=tree[ls(p)]+tree[rs(p)];
	tree[p]%=mod;
}
void build(ll p,ll pl,ll pr){
	tag[p]=0;
	if(pl==pr){
		tree[p]=w[pl];
		tree[p]%=mod;
		return;
	}
	ll mid=(pl+pr)>>1;
	build(ls(p),pl,mid);
	build(rs(p),mid+1,pr);
	push_up(p);
}
void addtag(ll p,ll pl,ll pr,ll d){
	tag[p]+=d;
	tree[p]+=d*(pr-pl+1);
	tree[p]%=mod;
}
void push_down(ll p,ll pl,ll pr){
	if(tag[p]){
		ll mid=(pl+pr)>>1;
		addtag(ls(p),pl,mid,tag[p]);
		addtag(rs(p),mid+1,pr,tag[p]);
		tag[p]=0;
	}
}
void update(ll p,ll L,ll R,ll pl,ll pr,ll d){
	if(L<=pl&&R>=pr){
		addtag(p,pl,pr,d);
		return;
	}
	push_down(p,pl,pr);
	ll mid=(pl+pr)>>1;
	if(L<=mid)update(ls(p),L,R,pl,mid,d);
	if(R>mid)update(rs(p),L,R,mid+1,pr,d);
	push_up(p);
}
void updatel(ll x,ll y,ll z){
	while(top[x]!=top[y]){
		if(deep[top[x]]<deep[top[y]])swap(x,y);
		update(1,id[top[x]],id[x],1,n,z);
		x=fa[top[x]];
	}
	if(deep[x]>deep[y])swap(x,y);
	update(1,id[x],id[y],1,n,z);
} 
ll query(ll p,ll L,ll R,ll pl,ll pr){
	ll res=0;
	if(L<=pl&&R>=pr){
		return tree[p]%=mod;
	}
	push_down(p,pl,pr);
	ll mid=(pl+pr)>>1;
	if(L<=mid)res+=query(ls(p),L,R,pl,mid);
	if(R>mid)res+=query(rs(p),L,R,mid+1,pr);
	return res;
}
ll queryl(ll x,ll y){
	ll an=0;
	while(top[x]!=top[y]){
		if(deep[top[x]]<deep[top[y]])swap(x,y);
		an+=query(1,id[top[x]],id[x],1,n);
		an%=mod;
		x=fa[top[x]];
	}
	if(deep[x]>deep[y])swap(x,y);
	an+=query(1,id[x],id[y],1,n);
	return an%mod;
} 
void updatet(ll x,ll z){
	update(1,id[x],id[x]+siz[x]-1,1,n,z);
}
ll queryt(ll x){
	return query(1,id[x],id[x]+siz[x]-1,1,n)%mod;
}
int main(){
	for(int i=0;i<N*2;i++){
		e[i].nxt=-1;
		head[i]=-1;
	}
	cin>>n>>m>>r>>mod;
	for(int i=1;i<=n;i++){
		cin>>a[i];
	}
	for(int i=1;i<n;i++){
		ll u,v;
		cin>>u>>v;
		add(u,v);
		add(v,u);
	}
	dfs1(r,0);
	dfs2(r,r);
	build(1,1,n);
	for(int i=1;i<=m;i++){
		ll o,x,y,z;
		cin>>o;
		if(o==1){
			cin>>x>>y>>z;
			updatel(x,y,z);
		}else if(o==2){
			cin>>x>>y;
			cout<<queryl(x,y)%mod<<endl;
		}else if(o==3){
			cin>>x>>z;
			updatet(x,z);
		}else{
			cin>>x;
			cout<<queryt(x)%mod<<endl;
		}
	}
	return 0;
}
posted @ 2024-08-06 19:35  小惰惰  阅读(16)  评论(0编辑  收藏  举报
/* 鼠标点击求赞文字特效 */