【学习笔记】树链剖分

树链剖分——轻重链剖分

其实是复习啦

引入

先解释一些变量和名词,防止后面混淆

id DFS序

deep 深度

fa 父节点

val 点的权值

siz 子树大小

max_son 重儿子

top 一条链的顶端

重儿子 子树最大的儿子

轻儿子 !重儿子(众所周知!是非的意思)

重边 父节点与重儿子的连边

轻边 父节点与轻儿子的连边

重链 好几条重边连成的路径

轻链 好几条轻边连成的路径

问题描述

给定一棵树,进行如下操作

将树从 x 到 y 结点最短路径上所有节点的值都加上 z

求树从 x 到 y 结点最短路径上所有节点的值之和

将以 x 为根节点的子树内所有节点值都加上 z

求以 x 为根节点的子树内所有节点值之和

思想概括

看到问题,我们不难想到树上差分和LCA

其实树剖大致上就是用把树搞成链,然后用数据结构维护起来

而进行操作时向上跳的不是节点,而是一整条链

预处理

DFS预处理出max_son和DFS序等信息

第一遍DFS,预处理重儿子、深度、子树大小和父节点

void dfs_first(int x,int y)
{
	siz[x]=1;
	deep[x]=deep[y]+1;
	fa[x]=y;
	for(int i=head[x];i;i=e[i].next)
	{
		int v=e[i].to;
		if(v==y)
		{
			continue;
		}
		dfs_first(v,x);
		siz[x]+=siz[v];
		if(siz[v]>siz[max_son[x]])
		{
			max_son[x]=v;
		}
	}
}

第二遍DFS,预处理出DFS序并且划分成链(预处理top)

void dfs_second(int x,int y) //这里的y是top呦
{
	top[x]=y;
	id[x]=++tot;
	e[tot].val=val[x];
	if(!max_son[x]) //为了更好地用数据结构维护,按轻重链划分
	{
		return ;
	}
	dfs_second(max_son[x],top[x]);
	for(int i=head[x];i;i=e[i].next)
	{
		int v=e[i].to;
		if(v!=fa[x] && v!=max_son[x])
		{
			dfs_second(v,v); //划分新链,这个节点的top是它自己
		}
	}
}

数据结构维护

我用的是线段树啦

其实就是板子

void build(int p,int l,int r)
{
	t[p].l=l;
	t[p].r=r;
	t[p].add=0;
	if(l==r)
	{
		t[p].val=e[l].val;
		
		if(t[p].val>mod)
		{
			t[p].val%=mod;	
		}
		return ;
	}
	int mid=(l+r)>>1;
	build(p*2,l,mid);
	build(p*2+1,mid+1,r);
	t[p].val=(t[p*2].val+t[p*2+1].val)%mod;
}

void push_down(int p)
{
	if(t[p].add)
	{
		t[p*2].val+=(t[p*2].r-t[p*2].l+1)*t[p].add;
		t[p*2+1].val+=(t[p*2+1].r-t[p*2+1].l+1)*t[p].add;
		t[p*2].add+=t[p].add;
		t[p*2+1].add+=t[p].add;
		t[p*2].val%=mod;
		t[p*2+1].val%=mod;
		t[p].add=0;
	}
}

void update(int p,int l,int r,int k)
{
	if(l<=t[p].l && t[p].r<=r)
	{
		t[p].val+=(t[p].r-t[p].l+1)*k;
		t[p].add+=k;
		return ;
	}
	
	push_down(p);
	
	int mid=(t[p].l+t[p].r)>>1;
	if(l<=mid)
	{
		update(p*2,l,r,k);
	}
	if(r>mid)
	{
		update(p*2+1,l,r,k);
	}
	t[p].val=(t[p*2].val+t[p*2+1].val)%mod;
}

int query(int p,int l,int r)
{
	if(l<=t[p].l && t[p].r<=r)
	{
		return t[p].val;
	}
	push_down(p);
	int mid=(t[p].l+t[p].r)>>1;
	int sum=0;
	if(l<=mid)
	{
		sum+=query(p*2,l,r);
		sum%=mod;
	}
	if(r>mid)
	{
		sum+=query(p*2+1,l,r);
		sum%=mod;
	}
	return sum;
}

维护的信息视情况而定

查询/修改操作

类似求LCA

查询

void get_sum(int x,int y)
{
	int ans=0;
	while(top[x]!=top[y])
	{
		if(deep[top[x]]<deep[top[y]])
		{
			swap(x,y);
		}
		ans+=query(1,id[top[x]],id[x]);
		ans%=mod;
		x=fa[top[x]];
	}
	if(deep[x]>deep[y])
	{
		swap(x,y);
	}
	ans+=query(1,id[x],id[y]);
	cout<<ans%mod<<endl;
}

修改

void change(int x,int y,int z)
{
	z%=mod;
	while(top[x]!=top[y])
	{
		if(deep[top[x]]<deep[top[y]])
		{
			swap(x,y);
		}
		update(1,id[top[x]],id[x],z);
		x=fa[top[x]];
	}
	
	if(deep[x]>deep[y])
	{
		swap(x,y);
	}
	update(1,id[x],id[y],z);
}

对子树进行操作

根据DFS预处理DFS序进行

节点x的子树在维护时的范围是id[x]到id[x]+siz[x]-1

然后就直接用线段树来搞啦

void change_son(int x,int y)
{
	update(1,id[x],id[x]+siz[x]-1,y);
}

void sum_son(int x)
{
	cout<<query(1,id[x],id[x]+siz[x]-1)<<endl;
}

模板の代码

点击查看代码
#include<iostream>
#include<cstdio>
#include<cstring>
#include<string>
#include<cmath>
#include<algorithm>
#define int long long

using namespace std;

const int maxn=1e5+5;

inline int read()
{
	int w=0,f=1;
	char ch=getchar();
	while(ch<'0' || ch>'9')
	{
		if(ch=='-')	
		{
			f=-1;
		}
		ch=getchar();
	}
	while(ch>='0' && ch<='9')
	{
		w=(w<<3)+(w<<1)+(ch^48);
		ch=getchar();
	}
	return w*f;
}

int n,m,r,mod;

int tot;

int head[maxn];

struct s_t
{
	int l;
	int r;
	int val;
	int add;
}t[maxn*4];

struct edge
{
	int to;
	int next;
	int val;
}e[maxn*2];

int top[maxn];
int max_son[maxn];
int id[maxn];
int val[maxn];
int deep[maxn];
int siz[maxn];
int fa[maxn];

void add(int x,int y)
{
	tot++;
	e[tot].to=y;
	e[tot].next=head[x];
	head[x]=tot;
}

void dfs_first(int x,int y)
{
	siz[x]=1;
	deep[x]=deep[y]+1;
	fa[x]=y;
	for(int i=head[x];i;i=e[i].next)
	{
		int v=e[i].to;
		if(v==y)
		{
			continue;
		}
		dfs_first(v,x);
		siz[x]+=siz[v];
		if(siz[v]>siz[max_son[x]])
		{
			max_son[x]=v;
		}
	}
}

void dfs_second(int x,int y)
{
	top[x]=y;
	id[x]=++tot;
	e[tot].val=val[x];
	if(!max_son[x])
	{
		return ;
	}
	dfs_second(max_son[x],top[x]);
	for(int i=head[x];i;i=e[i].next)
	{
		int v=e[i].to;
		if(v!=fa[x] && v!=max_son[x])
		{
			dfs_second(v,v);
		}
	}
}

void build(int p,int l,int r)
{
	t[p].l=l;
	t[p].r=r;
	t[p].add=0;
	if(l==r)
	{
		t[p].val=e[l].val;
		
		if(t[p].val>mod)
		{
			t[p].val%=mod;	
		}
		return ;
	}
	int mid=(l+r)>>1;
	build(p*2,l,mid);
	build(p*2+1,mid+1,r);
	t[p].val=(t[p*2].val+t[p*2+1].val)%mod;
}

void push_down(int p)
{
	if(t[p].add)
	{
		t[p*2].val+=(t[p*2].r-t[p*2].l+1)*t[p].add;
		t[p*2+1].val+=(t[p*2+1].r-t[p*2+1].l+1)*t[p].add;
		t[p*2].add+=t[p].add;
		t[p*2+1].add+=t[p].add;
		t[p*2].val%=mod;
		t[p*2+1].val%=mod;
		t[p].add=0;
	}
}

void update(int p,int l,int r,int k)
{
	if(l<=t[p].l && t[p].r<=r)
	{
		t[p].val+=(t[p].r-t[p].l+1)*k;
		t[p].add+=k;
		return ;
	}
	
	push_down(p);
	
	int mid=(t[p].l+t[p].r)>>1;
	if(l<=mid)
	{
		update(p*2,l,r,k);
	}
	if(r>mid)
	{
		update(p*2+1,l,r,k);
	}
	t[p].val=(t[p*2].val+t[p*2+1].val)%mod;
}

int query(int p,int l,int r)
{
	if(l<=t[p].l && t[p].r<=r)
	{
		return t[p].val;
	}
	push_down(p);
	int mid=(t[p].l+t[p].r)>>1;
	int sum=0;
	if(l<=mid)
	{
		sum+=query(p*2,l,r);
		sum%=mod;
	}
	if(r>mid)
	{
		sum+=query(p*2+1,l,r);
		sum%=mod;
	}
	return sum;
}

void change(int x,int y,int z)
{
	z%=mod;
	while(top[x]!=top[y])
	{
		if(deep[top[x]]<deep[top[y]])
		{
			swap(x,y);
		}
		update(1,id[top[x]],id[x],z);
		x=fa[top[x]];
	}
	
	if(deep[x]>deep[y])
	{
		swap(x,y);
	}
	update(1,id[x],id[y],z);
}

void get_sum(int x,int y)
{
	int ans=0;
	while(top[x]!=top[y])
	{
		if(deep[top[x]]<deep[top[y]])
		{
			swap(x,y);
		}
		ans+=query(1,id[top[x]],id[x]);
		ans%=mod;
		x=fa[top[x]];
	}
	if(deep[x]>deep[y])
	{
		swap(x,y);
	}
	ans+=query(1,id[x],id[y]);
	cout<<ans%mod<<endl;
}

void change_son(int x,int y)
{
	update(1,id[x],id[x]+siz[x]-1,y);
}

void sum_son(int x)
{
	cout<<query(1,id[x],id[x]+siz[x]-1)<<endl;
}

signed main()
{
	n=read();
	m=read();
	r=read();
	mod=read();
	for(int i=1;i<=n;i++)
	{
		val[i]=read();
	}
	int x;
	int y;
	for(int i=1;i<n;i++)
	{
		int x=read();
		int y=read();
		add(x,y);
		add(y,x);
	}
	
	tot=0;
	dfs_first(r,0);
	tot=0;
	dfs_second(r,r);
	build(1,1,n);
	int z;
	for(int i=1;i<=m;i++)
	{
		int opt=read();
		if(opt==1)
		{
			x=read();
			y=read();
			z=read();
			change(x,y,z);
		}
		if(opt==2)
		{
			x=read();
			y=read();
			get_sum(x,y);
		}
		if(opt==3)
		{
			x=read();
			z=read();
			change_son(x,z);
		}
		if(opt==4)
		{
			x=read();
			sum_son(x);
		}
	}
	return 0;
}

拓展 树链剖分求LCA

之前机房同学说不会树剖求LCA(就不说是谁了)

慢了a little但是比倍增、Tarjan什么的好写且好理解(个人认为)

其实还是那两个DFS+LCA

还是整条链不断跳的过程

int LCA(int x,int y)
{
	while(top[x]!=top[y])
	{
		if(deep[top[x]]<deep[top[y]])
		{
			swap(x,y);
		}
		x=fa[top[x]];
	}
	return (deep[x]<deep[y]) ? x : y;
}

用三目压了下行

练习

板子

还是板子

还TM的是板子

可能是写过最长的一篇博客 虽然大部分都是代码水博

人生はどんなに苦しいのか、だからコーヒーはせめて甘いほうがいい

posted @ 2022-09-22 16:07  NinT_W  阅读(38)  评论(2编辑  收藏  举报