【洛谷P3384】【模板】树链剖分

题目大意:

题目链接:https://www.luogu.org/problem/P3384
如题,已知一棵包含N个结点的树(连通且无环),每个节点上包含一个数值,需要支持以下操作:

操作1: 格式:1 x y z1\ x\ y\ z 表示将树从xxyy结点最短路径上所有节点的值都加上zz

操作2: 格式: 2 x y2\ x\ y 表示求树从xxyy结点最短路径上所有节点的值之和

操作3: 格式: 3 x z3\ x\ z 表示将以xx为根节点的子树内所有节点值都加上zz

操作4: 格式: 4 x4\ x 表示求以xx为根节点的子树内所有节点值之和


思路:

树链剖分模板题。
大部分思路、学习过程来自这里:https://www.luogu.org/blog/communist/shu-lian-pou-fen-yang-xie
我们定义如下内容

  • 重儿子:一个结点的儿子中,子树最大的儿子
  • 轻儿子:该节点除了重儿子以外的儿子
  • 重边:重儿子与他父亲的连边
  • 轻边:轻儿子与他父亲的连边
  • 重链:多条重链连接起来的路径
  • 轻链:多条轻边连接起来的路径

之后我们需要进行两次dfsdfs
第一次dfsdfs我们求出每一个节点的父亲,深度,以及子树大小。分别用fa[x],dep[x],size[x]fa[x],dep[x],size[x]记录。
同时还要记录除叶子外每个节点的重儿子。用son[x]son[x]记录。
然后第二次dfsdfs我们将重链优先编号,这样用数据结构维护时就可以更加方便。同时,我们需要满任意节点的子树编号依然为一段连续的区间。同时记录下每一条重链的起始点,也就是该重链的深度最浅的点,然后记录下每一个点的编号(注意这个编号和dfsdfs序略有不同),以及该编号对应的节点。分别用top[x],id[x],rk[x]top[x],id[x],rk[x]记录。

void dfs1(int x,int f)
{
	fa[x]=f;
	dep[x]=dep[fa[x]]+1;
	size[x]=1;
	for (int i=head[x];~i;i=e[i].next)
	{
		int y=e[i].to;
		if (y!=fa[x])
		{
			dfs1(y,x);
			size[x]+=size[y];
			if (size[y]>size[son[x]]) son[x]=y;
		}
	}
}

void dfs2(int x,int tp)
{
	top[x]=tp;
	id[x]=++cnt;
	rk[cnt]=x;
	if (son[x]) dfs2(son[x],tp);
	for (int i=head[x];~i;i=e[i].next)
	{
		int y=e[i].to;
		if (y!=fa[x] && y!=son[x]) dfs2(y,y);
	}
}

然后接下来我们就要处理操作了


1.将树从xxyy结点最短路径上所有节点的值都加上zz

我们已经保证了每一条重链编号是连续的,所以我们每次在线段树中只要维护若干个区间加。每次选取x,yx,y两点中深度较深的点,然后将该点到该点所在重链的toptop区间加zz即可,然后把xx赋值为fa[top[x]]fa[top[x]]

2.求树从xxyy结点最短路径上所有节点的值之和

和操作1的思路是相同的,每次求xx到其重链toptop的区间和

3.将以xx为根节点的子树内所有节点值都加上zz

由于序列只是在dfsdfs序上稍加修改,我们依然可以保证一棵子树任然在同一个区间。
那么如果这棵子树的根的编号为xx,我们已经处理出了该子树的大小size[x]size[x],所以我们要进行区间加的区间为[x,x+size[x]1][x,x+size[x]-1]

4.求以xx为根节点的子树内所有节点值之和

这个其实就是线段树的模板,区间查询[x,x+size[x]1][x,x+size[x]-1]的和即可。

可以证明树链剖分的时间复杂度为O(nlog2n)O(n\log^2 n)


代码:

#include <cstdio>
#include <cstring>
#include <algorithm>
using namespace std;

const int N=100010;
int size[N],fa[N],dep[N],id[N],rk[N],son[N],top[N],a[N],head[N];
int n,m,root,MOD,opt,cnt,tot;

struct edge
{
	int next,to;
}e[N*2];

struct Treenode
{
	int l,r,sum,lazy;
};

struct Tree
{
	Treenode tree[N*4];
	
	int len(int x)
	{
		return tree[x].r-tree[x].l+1;
	}
	
	void pushup(int x)
	{
		tree[x].sum=(tree[x*2].sum+tree[x*2+1].sum)%MOD;
	}
	
	void pushdown(int x)
	{
		if (tree[x].lazy)
		{
			tree[x*2].lazy=(tree[x*2].lazy+tree[x].lazy)%MOD;
			tree[x*2+1].lazy=(tree[x*2+1].lazy+tree[x].lazy)%MOD;
			tree[x*2].sum=(tree[x*2].sum+tree[x].lazy*len(x*2))%MOD;
			tree[x*2+1].sum=(tree[x*2+1].sum+tree[x].lazy*len(x*2+1))%MOD;
			tree[x].lazy=0;
		}
	}
	
	void build(int x)
	{
		if (tree[x].l==tree[x].r)
		{
			tree[x].sum=a[rk[tree[x].l]]%MOD;
			return;
		}
		int mid=(tree[x].l+tree[x].r)>>1;
		tree[x*2].l=tree[x].l;
		tree[x*2].r=mid;
		tree[x*2+1].l=mid+1;
		tree[x*2+1].r=tree[x].r;
		build(x*2); build(x*2+1);
		pushup(x);
	}
	
	void update(int x,int l,int r,int val)
	{
		if (tree[x].l==l && tree[x].r==r)
		{
			tree[x].sum=(tree[x].sum+val*len(x))%MOD;
			tree[x].lazy=(tree[x].lazy+val)%MOD;
			return;
		}
		pushdown(x);
		int mid=(tree[x].l+tree[x].r)>>1;
		if (r<=mid) update(x*2,l,r,val);
		else if (l>mid) update(x*2+1,l,r,val);
		else update(x*2,l,mid,val),update(x*2+1,mid+1,r,val);
		pushup(x);
	}
	
	void addrange(int x,int y,int k)
	{
		while (top[x]!=top[y])
		{
			if (dep[top[x]]<dep[top[y]]) swap(x,y);
			update(1,id[top[x]],id[x],k);
			x=fa[top[x]];
		}
		if (id[x]>id[y]) update(1,id[y],id[x],k);
			else update(1,id[x],id[y],k);
	}
	
	int ask(int x,int l,int r)
	{
		if (tree[x].l==l && tree[x].r==r) return tree[x].sum;
		pushdown(x);
		int mid=(tree[x].l+tree[x].r)>>1;
		if (r<=mid) return ask(x*2,l,r);
		if (l>mid) return ask(x*2+1,l,r);
		return (ask(x*2,l,mid)+ask(x*2+1,mid+1,r))%MOD;
	}
	
	int askrange(int x,int y)
	{
		int ans=0;
		while (top[x]!=top[y])
		{
			if (dep[top[x]]<dep[top[y]]) swap(x,y);
			ans=(ans+ask(1,id[top[x]],id[x]))%MOD;
			x=fa[top[x]];
		}
		if (id[x]>id[y]) ans=(ans+ask(1,id[y],id[x]))%MOD;
			else ans=(ans+ask(1,id[x],id[y]))%MOD;
		return ans;
	}
}Tree;

void add(int from,int to)
{
	e[++tot].to=to;
	e[tot].next=head[from];
	head[from]=tot;
}

void dfs1(int x,int f)
{
	fa[x]=f;
	dep[x]=dep[fa[x]]+1;
	size[x]=1;
	for (int i=head[x];~i;i=e[i].next)
	{
		int y=e[i].to;
		if (y!=fa[x])
		{
			dfs1(y,x);
			size[x]+=size[y];
			if (size[y]>size[son[x]]) son[x]=y;
		}
	}
}

void dfs2(int x,int tp)
{
	top[x]=tp;
	id[x]=++cnt;
	rk[cnt]=x;
	if (son[x]) dfs2(son[x],tp);
	for (int i=head[x];~i;i=e[i].next)
	{
		int y=e[i].to;
		if (y!=fa[x] && y!=son[x]) dfs2(y,y);
	}
}

int main()
{
	memset(head,-1,sizeof(head));
	scanf("%d%d%d%d",&n,&m,&root,&MOD);
	for (int i=1;i<=n;i++)
		scanf("%d",&a[i]);
	for (int i=1,x,y;i<n;i++)
	{
		scanf("%d%d",&x,&y);
		add(x,y); add(y,x);
	}
	dfs1(root,0);
	dfs2(root,root);
	Tree.tree[1].l=1; Tree.tree[1].r=n;
	Tree.build(1);
	for (int i=1,x,y,z;i<=m;i++)
	{
		scanf("%d",&opt);
		if (opt==1)
		{
			scanf("%d%d%d",&x,&y,&z);
			Tree.addrange(x,y,z);
		}
		if (opt==2)
		{
			scanf("%d%d",&x,&y);
			printf("%d\n",Tree.askrange(x,y));
		}
		if (opt==3)
		{
			scanf("%d%d",&x,&y);
			Tree.update(1,id[x],id[x]+size[x]-1,y);
		}
		if (opt==4)
		{
			scanf("%d",&x);
			printf("%d\n",Tree.ask(1,id[x],id[x]+size[x]-1));
		}
	}
	return 0;
}
posted @ 2019-08-19 10:36  全OI最菜  阅读(135)  评论(0编辑  收藏  举报