【洛谷P6177】Count on a tree II /【模板】树分块

题目

题目链接:https://www.luogu.com.cn/problem/P6177
给定一个 \(n\) 个节点的树,每个节点上有一个整数,\(i\) 号点的整数为 \(val_i\)
\(m\) 次询问,每次给出 \(u',v\),您需要将其解密得到 \(u,v\),并查询 \(u\)\(v\) 的路径上有多少个不同的整数。
解密方式:\(u=u'\operatorname{xor} lastans\)
\(lastans\) 为上一次询问的答案,若无询问则为 \(0\)

思路

首先在树上选择 \(\frac{n}{B}\) 个关键点,使得互为祖孙的相邻关键点之间的距离都不超过 \(B\)。这个可以通过每次贪心选择深度最大的点的 \(B\) 级祖先,然后把选择的点的子树割掉。这样每次至少割 \(B\) 个点,选择的关键点的数量最多是 \(\frac{n}{B}\)
然后对于互为祖孙的关键点之间预处理他们之间路径的颜色集合,扔进 bitset 中。可以先把相邻的求出来,然后递推一下。
对于每次询问 \(x,y\),找到他们的 LCA 点 \(p\)\(u\to p\)\(v\to p\) 的路径都可以拆分为中间一段关键点之间的路径,以及两边零散的 \(O(B)\) 个点。那么就把两边的暴力加进 bitset 中,再或上两条关键点之间路径的 bitset 即可。
时间复杂度 \(O(\frac{n^3}{\omega B^2}+m(B+\frac{n}{\omega}))\),空间复杂度 \(O(\frac{n^3}{\omega B^2})\)。取 \(B=300\) 即可。

代码

#include <bits/stdc++.h>
using namespace std;

const int N=40010,B=300,LG=16;
int n,m,tot,last,head[N],a[N],b[N],id[N],dep[N],f[N][LG+1],nxt[N/B+2][N/B+2];
bitset<N> bt[N/B+2][N/B+2],s;

struct edge
{
	int next,to;
}e[N*2];

void add(int from,int to)
{
	e[++tot]=(edge){head[from],to};
	head[from]=tot;
}

int dfs1(int x,int fa)
{
	f[x][0]=fa; dep[x]=dep[fa]+1;
	for (int i=1;i<=LG;i++)
		f[x][i]=f[f[x][i-1]][i-1];
	int maxd=0;
	for (int i=head[x];~i;i=e[i].next)
	{
		int v=e[i].to;
		if (v!=fa) maxd=max(maxd,dfs1(v,x));
	}
	if (maxd>=B) id[x]=++tot,maxd=0;
	return maxd+1;
}

int lca(int x,int y)
{
	if (dep[x]<dep[y]) swap(x,y);
	for (int i=LG;i>=0;i--)
		if (dep[f[x][i]]>=dep[y]) x=f[x][i];
	if (x==y) return x;
	for (int i=LG;i>=0;i--)
		if (f[x][i]!=f[y][i]) x=f[x][i],y=f[y][i];
	return f[x][0];
}

void dfs2(int x,int y)
{
	if (id[x])
	{
		nxt[id[x]][0]=x; nxt[id[x]][1]=y;
		for (int i=x;i!=y;i=f[i][0])
			bt[id[x]][id[y]][a[i]]=1;
		bt[id[x]][id[y]][a[y]]=1;
		for (int i=2,z;i<=N/B+1;i++)
		{
			z=nxt[id[x]][i]=nxt[id[y]][i-1];
			bt[id[x]][id[z]]|=bt[id[x]][id[y]];
			bt[id[x]][id[z]]|=bt[id[y]][id[z]];
		}
		y=x;
	}
	for (int i=head[x];~i;i=e[i].next)
	{
		int v=e[i].to;
		if (v!=f[x][0]) dfs2(v,y);
	}
}

int main()
{
	memset(head,-1,sizeof(head));
	scanf("%d%d",&n,&m);
	for (int i=1;i<=n;i++)
		scanf("%d",&a[i]),b[i]=a[i];
	for (int i=1,x,y;i<n;i++)
	{
		scanf("%d%d",&x,&y);
		add(x,y); add(y,x);
	}
	sort(b+1,b+1+n);
	tot=unique(b+1,b+1+n)-b-1;
	for (int i=1;i<=n;i++)
		a[i]=lower_bound(b+1,b+1+tot,a[i])-b;
	tot=0;
	dfs1(1,0); dfs2(1,0);
	while (m--)
	{
		int x,y,p;
		scanf("%d%d",&x,&y);
		x^=last; p=lca(x,y);
		s.reset(); s[a[p]]=1;
		for (;!id[x] && x!=p;x=f[x][0]) s[a[x]]=1;
		for (;!id[y] && y!=p;y=f[y][0]) s[a[y]]=1;
		if (id[x])
			for (int i=1;;i++)
				if (dep[nxt[id[x]][i]]<dep[p])
				{
					s|=bt[id[x]][id[nxt[id[x]][i-1]]];
					x=nxt[id[x]][i-1];
					break;
				}
		if (id[y])
			for (int i=1;;i++)
				if (dep[nxt[id[y]][i]]<dep[p])
				{
					s|=bt[id[y]][id[nxt[id[y]][i-1]]];
					y=nxt[id[y]][i-1];
					break;
				}
		for (;x!=p;x=f[x][0]) s[a[x]]=1;
		for (;y!=p;y=f[y][0]) s[a[y]]=1;
		cout<<(last=s.count())<<"\n";
	}
	return 0;
}
posted @ 2021-08-20 09:24  stoorz  阅读(82)  评论(0编辑  收藏  举报