【P5903】【模板】树上 k 级祖先

题目

题目链接:https://www.luogu.com.cn/problem/P5903
给定一棵 \(n\) 个点的有根树。

\(q\) 次询问,第 \(i\) 次询问给定 \(x_i, k_i\),要求点 \(x_i\)\(k_i\) 级祖先。

思路

长剖模板题。
长链剖分是按照子树内最长的链来树剖。在求树上 \(k\) 级祖先时,可以做到 \(O(n\log n)\) 预处理,单次 \(O(1)\) 查询。
首先我们 dfs 一遍,求出每一个节点的 \(2^k\) 级祖先,并长剖。对于每一条长链的顶端节点,假设这条长链长度为 \(d\),那么在这个节点记录从这个节点开始,往上 \(d\) 级祖先,以及按照长链往下 \(d\) 级子孙。
询问时,我们先往 \(x\) 上跳 \(2^{k'}\) 级祖先,满足 \(2^{k'}\leq k\) 并且尽量大。根据长链剖分的性质,这个点所在长链长度一定不小于 \(2^{k'}\)
由于 \(k-2^{k'}\) 一定小于 \(2^{k'}\),所以我们在这条链的顶端也一定记录了 \(x\)\(k\) 级祖先的信息。所以直接跳到这条长链的顶端,根据剩余步数选择往下或往上跳若干步即可。
时间复杂度 \(O(n\log n+Q)\)

代码

#include <bits/stdc++.h>
using namespace std;
typedef long long ll;
typedef unsigned int uint;

const int N=500010,LG=20;
int n,Q,rt,last,tot,lg[N],head[N],maxd[N],son[N],dep[N],top[N],f[N][LG+1];
uint seed;
ll ans;
vector<int> up[N],down[N];

inline uint get(uint x) {
	x ^= x << 13;
	x ^= x >> 17;
	x ^= x << 5;
	return seed = x; 
}

struct edge
{
	int next,to;
}e[N];

void add(int from,int to)
{
	e[++tot]=(edge){head[from],to};
	head[from]=tot;
}

void dfs1(int x)
{
	dep[x]=dep[f[x][0]]+1;
	for (int i=1;i<=LG;i++)
		f[x][i]=f[f[x][i-1]][i-1];
	for (int i=head[x];~i;i=e[i].next)
	{
		int v=e[i].to;
		if (v!=f[x][0])
		{
			dfs1(v);
			maxd[x]=max(maxd[x],maxd[v]+1);
			if (maxd[v]>maxd[son[x]]) son[x]=v;
		}
	}
}

void dfs2(int x,int tp)
{
	top[x]=tp;
	if (x==tp)
	{
		for (int i=x,j=0;j<=maxd[x];j++,i=f[i][0])
			up[x].push_back(i);
		for (int i=x,j=0;j<=maxd[x];j++,i=son[i])
			down[x].push_back(i);
	}
	if (son[x]) dfs2(son[x],tp);
	for (int i=head[x];~i;i=e[i].next)
	{
		int v=e[i].to;
		if (v!=f[x][0] && v!=son[x])
			dfs2(v,v);
	}
}

int query(int x,int k)
{
	if (!k) return x;
	x=f[x][lg[k]]; k-=(1<<lg[k]);
	k-=dep[x]-dep[top[x]]; x=top[x];
	if (k>=0) return up[x][k];
		else return down[x][-k];
}

int main()
{
	memset(head,-1,sizeof(head));
	scanf("%d%d",&n,&Q);
	scanf("%u",&seed);
	for (int i=1;i<=n;i++)
	{
		scanf("%d",&f[i][0]);
		if (!f[i][0]) rt=i;
		add(f[i][0],i);
	}
	for (int i=2;i<=n;i++)
		lg[i]=lg[i>>1]+1;
	dfs1(rt); dfs2(rt,rt);
	for (int i=1;i<=Q;i++)
	{
		int x=(get(seed)^last)%n+1;
		int k=(get(seed)^last)%dep[x];
		last=query(x,k);
		ans^=1LL*i*last;
	}
	printf("%lld",ans);
	return 0;
}
posted @ 2020-12-30 15:41  stoorz  阅读(99)  评论(0编辑  收藏  举报