【洛谷P4211】LCA

题目

题目链接:https://www.luogu.com.cn/problem/P4211
给出一个 \(n\) 个节点的有根树(编号为 \(0\)\(n-1\),根节点为 \(0\))。
一个点的深度定义为这个节点到根的距离 \(+1\)
\(dep[i]\) 表示点i的深度,\(LCA(i,j)\) 表示 \(i\)\(j\) 的最近公共祖先。
\(q\) 次询问,每次询问给出 \(l\ r\ z\),求 \(\sum_{i=l}^r dep[LCA(i,z)]\)

思路

首先 \(dep[lca(x,y)]\) 等价于把 \(x\) 的所有祖先节点标记为 1,然后求 \(y\) 的祖先节点的权值和。
那么 \(\sum^{r}_{i=l} dep[lca(x,i)]\) 等价于把 \([l,r]\) 的所有点的祖先节点全部加一,求 \(x\) 的祖先节点的权值和。
把询问拆成 \(1\sim r\) 的和减去 \(1\sim l-1\) 的和,然后按照编号从小到大枚举点,树剖 + 线段树将这个点到 1 的路径的点权值全部加一。
对于询问就直接求到 1 的权值和即可。
时间复杂度 \(O(n\log^2 n)\)

代码

#include <cstdio>
#include <vector>
#include <cstring>
#include <algorithm>
using namespace std;

const int N=50010,MOD=201314;
int head[N],son[N],fa[N],size[N],id[N],rk[N],top[N];
int n,Q,tot;
vector<int> pos[N];

struct edge
{
	int next,to;
}e[N];

struct Query
{
	int x,l,r,ans;
}ask[N];

void add(int from,int to)
{
	e[++tot].to=to;
	e[tot].next=head[from];
	head[from]=tot;
}

void dfs1(int x)
{
	size[x]++;
	for (int i=head[x];~i;i=e[i].next)
	{
		int v=e[i].to;
		dfs1(v);
		size[x]+=size[v];
		if (size[v]>size[son[x]]) son[x]=v;
	}
}

void dfs2(int x,int tp)
{
	id[x]=++tot; rk[tot]=x; top[x]=tp;
	if (son[x]) dfs2(son[x],tp);
	for (int i=head[x];~i;i=e[i].next)
	{
		int v=e[i].to;
		if (v!=son[x]) dfs2(v,v);
	}
}

struct SegTree
{
	int l[N*4],r[N*4],len[N*4],sum[N*4],lazy[N*4];
	
	void pushdown(int x)
	{
		if (lazy[x])
		{
			sum[x*2]=(sum[x*2]+lazy[x]*len[x*2])%MOD;
			sum[x*2+1]=(sum[x*2+1]+lazy[x]*len[x*2+1])%MOD;
			lazy[x*2]=(lazy[x*2]+lazy[x])%MOD;
			lazy[x*2+1]=(lazy[x*2+1]+lazy[x])%MOD;
			lazy[x]=0;
		}
	}
	
	void build(int x,int ql,int qr)
	{
		l[x]=ql; r[x]=qr; len[x]=qr-ql+1;
		if (ql==qr) return;
		int mid=(ql+qr)>>1;
		build(x*2,ql,mid);
		build(x*2+1,mid+1,qr);
	}
	
	void pushup(int x)
	{
		sum[x]=(sum[x*2]+sum[x*2+1])%MOD;
	}
	
	void update(int x,int ql,int qr)
	{
		if (l[x]==ql && r[x]==qr)
		{
			sum[x]=(sum[x]+len[x])%MOD; lazy[x]++;
			return;
		}
		pushdown(x);
		int mid=(l[x]+r[x])>>1;
		if (qr<=mid) update(x*2,ql,qr);
		else if (ql>mid) update(x*2+1,ql,qr);
		else update(x*2,ql,mid),update(x*2+1,mid+1,qr);
		pushup(x);
	}
	
	int query(int x,int ql,int qr)
	{
		if (l[x]==ql && r[x]==qr)
			return sum[x];
		pushdown(x);
		int mid=(l[x]+r[x])>>1;
		if (qr<=mid) return query(x*2,ql,qr);
		if (ql>mid) return query(x*2+1,ql,qr);
		return query(x*2,ql,mid)+query(x*2+1,mid+1,qr);
	}
}seg;

void Update(int x)
{
	while (x)
	{
		seg.update(1,id[top[x]],id[x]);
		x=fa[top[x]];
	}
}

int Query(int x)
{
	int ans=0;
	while (x)
	{
		ans=(ans+seg.query(1,id[top[x]],id[x]))%MOD;
		x=fa[top[x]];
	}
	return ans;
}

int main()
{
	memset(head,-1,sizeof(head));
	scanf("%d%d",&n,&Q);
	for (int i=2,x;i<=n;i++)
	{
		scanf("%d",&x);
		add(x+1,i); fa[i]=x+1;
	}
	for (int i=1;i<=Q;i++)
	{
		scanf("%d%d%d",&ask[i].l,&ask[i].r,&ask[i].x);
		ask[i].r++; ask[i].x++;
		pos[ask[i].l].push_back(i);
		pos[ask[i].r].push_back(i);
	}
	tot=0;
	dfs1(1); dfs2(1,1);
	seg.build(1,1,n);
	for (int i=0;i<=n;i++)
	{
		Update(i);
		for (int j=0;j<pos[i].size();j++)
		{
			int k=pos[i][j],s=Query(ask[k].x);
			if (ask[k].l==i) ask[k].ans-=s;
				else ask[k].ans+=s;
		}
	}
	for (int i=1;i<=Q;i++)
		printf("%d\n",(ask[i].ans%MOD+MOD)%MOD);
	return 0;
}
posted @ 2020-06-13 10:10  stoorz  阅读(152)  评论(0编辑  收藏  举报