树链剖分板子

树链剖分,顾名思义,将一个树以链的形式分割同时用数据结构将每个链维护起来,以方便树上操作。
我们当前一般说的树链剖分即为重链剖分
所以这里大致说一下重链剖分的定义及步骤。
对于一个根节点,我们定义,其子节点对应的子树所包含的节点数最多的那个节点被称为根节点的重儿子。
我们以此将整个树的所有分支都分为两种,一种是重边连向重儿子,反之为轻边。
我们定义一个根节点向上无重边,那么以其为起点向下找重边及重儿子构成的链称之为重链。
没有重边重儿子的叶子节点若上方非重边相连,它自己构成一个重链。
那么我们可以想到,因为一个树上出了重边就是轻边,重边会将路径上的所有节点包含在重链中,轻边下面自成新的重链。
所以整个树就被重链完全不重合地分割成了几个不同的链。
在代码中实现重链剖分,dfs明显必须,但是我们要维护的信息有哪些呢。
首先明显的是重儿子信息,父节点信息,重儿子的获得需要知道子树大小,所以size必须有。
但是考虑到一个树它可能给我们的是无向图,不知父子关系。
所以要维护新的节点编号,这就要求维护新编号信息,对应新编号对应的旧编号信息。
还因为我们是要通过链划分,将树套到数据结构里面,所以一个链的起止点也是必须的。
而我们递归至最深层节点就是链的结尾点,所以直接存重链起始点,在递归最深层建数据结构就可以了。
我们通过如下的两次dfs得到树的相关信息。

void get_son(int now)
{
	size[now]=1;
	for(int i=head[now];i;i=edge[i].nx)
	{
		int t=edge[i].t;
		if(t!=fa[now])
			fa[t]=now,dfs(t),
			if(size[son[now]]<size[t])son[now]=t;
			size[now]+=size[t];
	}
}
void dfs(int now)
{
	id[now]=++tot;
	rk[tot]=now;
	if(!son[now])return;//这里就建数据结构。
	top[son[now]]=top[now],dfs(son[now]);
	for(int i=head[now];i;i=edge[i].nx)
	{
		int t=edge[i].t;
		if(t!=son[now]&&t!=fa[now])
			top[t]=t,dfs(t);
	}
}

板子题

#include<bits/stdc++.h>
#define ll long long
#define qr qr()
#define pa pair<int,int>
#define fr first
#define sc second
#define lc tree[rt].ls
#define rc tree[rt].rs
using namespace std;
const int N=2e5+200;
int m,n,nd_tot,mod,tot,nd_rt[N],dep[N],cnt,num[N],head[N],top[N],id[N],rk[N],size[N],son[N],fa[N];
inline int qr
{
	int x=0;char ch=getchar();bool f=0;
	while(ch>57||ch<48)
	{
		if(ch=='-')f=1;
		ch=getchar();
	}
	while(ch<=57&&ch>=48)x=(x<<3)+(x<<1)+(ch^48),ch=getchar();
	return f?-x:x;
}
struct node{
	int t,nx;
}edge[N];
struct What_can_I_say{
	int ls,rs,l,r;
	ll sum,mx;
}tree[N<<2];
inline void add(int f,int t)
{
	edge[++tot]={t,head[f]};
	head[f]=tot;
}
inline void pushup(int rt)
{
	tree[rt].sum=tree[tree[rt].ls].sum+tree[tree[rt].rs].sum;
	tree[rt].mx=max(tree[tree[rt].ls].mx,tree[tree[rt].rs].mx);
}
void build(int &rt,int l,int r)
{

	rt=++cnt;
	tree[rt].l=l;
	tree[rt].r=r;
	if(l==r)
	{
		tree[rt].sum=num[id[l]];
		tree[rt].mx=num[id[l]];
		return;
	}
	int md=(l+r)/2;
	build(lc,l,md);
	build(rc,md+1,r);
	pushup(rt);
}
void update(int rt,int st,int pos)
{
	if(tree[rt].l==tree[rt].r)
	{
		tree[rt].sum=pos;
		tree[rt].mx=pos;
		return;
	}
	if(tree[lc].r>=st)update(lc,st,pos);
	else update(rc,st,pos);
	pushup(rt);
}
void get_son(int now)
{
	size[now]=1;
	for(int i=head[now];i;i=edge[i].nx)
	{
		int t=edge[i].t;
		if(t!=fa[now])
		{
			fa[t]=now,dep[t]=dep[now]+1,get_son(t);
			if(size[son[now]]<size[t])son[now]=t;
			size[now]+=size[t];
		}
	}
}
void dfs(int now)
{
	rk[now]=++nd_tot;
	id[nd_tot]=now;
	if(!son[now])
	{
		build(nd_rt[top[now]],rk[top[now]],rk[now]);
		return;
	}
	top[son[now]]=top[now],dfs(son[now]);
	for(int i=head[now];i;i=edge[i].nx)
	{
		int t=edge[i].t;
		if(t!=son[now]&&t!=fa[now])
			top[t]=t,dfs(t);
	}
}
int get_max(int rt,int l,int r)
{
	if(tree[rt].l>=l&&tree[rt].r<=r)
		return tree[rt].mx;
	int ans=0xcfffffff;
	if(tree[lc].r>=l)ans=max(ans,get_max(lc,l,r));
	if(tree[lc].r<r)ans=max(ans,get_max(rc,l,r));
	return ans;
}
int get_sum(int rt,int l,int r)
{
	if(tree[rt].l>=l&&tree[rt].r<=r)
		return tree[rt].sum;
	int ans=0;
	if(tree[lc].r>=l) ans+=get_sum(lc,l,r);
	if(tree[lc].r<r) ans+=get_sum(rc,l,r);
	return ans;
}
int ask(int op,int a,int b)
{
	if(op&1)
	{
		int ans=0xcfffffff;
		while(1)
		{
			if(dep[top[a]]<dep[top[b]])swap(a,b);
			if(top[a]!=top[b])
			{
				ans=max(ans,get_max(nd_rt[top[a]],rk[top[a]],rk[a]));
				a=fa[top[a]];
			}
			else
			{
				if(rk[a]<rk[b])swap(a,b);
				ans=max(ans,get_max(nd_rt[top[a]],rk[b],rk[a]));
				return ans;
			}
		}
	}
	else
	{
		int ans=0;
		while(1)
		{
			if(dep[top[a]]<dep[top[b]])swap(a,b);
			if(top[a]^top[b])
			{
				ans+=get_sum(nd_rt[top[a]],rk[top[a]],rk[a]);
				a=fa[top[a]];
			}
			else
			{
				if(rk[a]<rk[b])swap(a,b);
				ans+=get_sum(nd_rt[top[a]],rk[b],rk[a]);
				return ans;
			}
		}
	}
}
void init()
{
	n=qr;
	for(int i=1;i<n;++i)
	{
		int f=qr,t=qr;
		add(f,t);
		add(t,f);
	}
	for(int i=1;i<=n;++i)
		num[i]=qr;
	m=qr;
	char op[50];
	top[1]=1;
	get_son(1);
	dfs(1);
	for(int i=1;i<=m;++i)
	{
		cin>>op+1;
		int a=qr,b=qr;
		if(op[2]=='M')
			printf("%d\n",ask(1,a,b));
		else if(op[2]=='S')
			printf("%d\n",ask(2,a,b));
		else
			update(nd_rt[top[a]],rk[a],b);
	}
}
int main()
{
	#ifndef ONLINE_JUDGE
	freopen("in.in","r",stdin);
	freopen("out.out","w",stdout);
	#endif
	init();
	return 0;
}
posted @ 2024-05-17 16:12  SLS-wwppcc  阅读(8)  评论(0编辑  收藏  举报