【洛谷P6177】Count on a tree II /【模板】树分块
题目
题目链接:https://www.luogu.com.cn/problem/P6177
给定一个 \(n\) 个节点的树,每个节点上有一个整数,\(i\) 号点的整数为 \(val_i\)。
有 \(m\) 次询问,每次给出 \(u',v\),您需要将其解密得到 \(u,v\),并查询 \(u\) 到 \(v\) 的路径上有多少个不同的整数。
解密方式:\(u=u'\operatorname{xor} lastans\)。
\(lastans\) 为上一次询问的答案,若无询问则为 \(0\)。
思路
首先在树上选择 \(\frac{n}{B}\) 个关键点,使得互为祖孙的相邻关键点之间的距离都不超过 \(B\)。这个可以通过每次贪心选择深度最大的点的 \(B\) 级祖先,然后把选择的点的子树割掉。这样每次至少割 \(B\) 个点,选择的关键点的数量最多是 \(\frac{n}{B}\)。
然后对于互为祖孙的关键点之间预处理他们之间路径的颜色集合,扔进 bitset 中。可以先把相邻的求出来,然后递推一下。
对于每次询问 \(x,y\),找到他们的 LCA 点 \(p\),\(u\to p\) 和 \(v\to p\) 的路径都可以拆分为中间一段关键点之间的路径,以及两边零散的 \(O(B)\) 个点。那么就把两边的暴力加进 bitset 中,再或上两条关键点之间路径的 bitset 即可。
时间复杂度 \(O(\frac{n^3}{\omega B^2}+m(B+\frac{n}{\omega}))\),空间复杂度 \(O(\frac{n^3}{\omega B^2})\)。取 \(B=300\) 即可。
代码
#include <bits/stdc++.h>
using namespace std;
const int N=40010,B=300,LG=16;
int n,m,tot,last,head[N],a[N],b[N],id[N],dep[N],f[N][LG+1],nxt[N/B+2][N/B+2];
bitset<N> bt[N/B+2][N/B+2],s;
struct edge
{
int next,to;
}e[N*2];
void add(int from,int to)
{
e[++tot]=(edge){head[from],to};
head[from]=tot;
}
int dfs1(int x,int fa)
{
f[x][0]=fa; dep[x]=dep[fa]+1;
for (int i=1;i<=LG;i++)
f[x][i]=f[f[x][i-1]][i-1];
int maxd=0;
for (int i=head[x];~i;i=e[i].next)
{
int v=e[i].to;
if (v!=fa) maxd=max(maxd,dfs1(v,x));
}
if (maxd>=B) id[x]=++tot,maxd=0;
return maxd+1;
}
int lca(int x,int y)
{
if (dep[x]<dep[y]) swap(x,y);
for (int i=LG;i>=0;i--)
if (dep[f[x][i]]>=dep[y]) x=f[x][i];
if (x==y) return x;
for (int i=LG;i>=0;i--)
if (f[x][i]!=f[y][i]) x=f[x][i],y=f[y][i];
return f[x][0];
}
void dfs2(int x,int y)
{
if (id[x])
{
nxt[id[x]][0]=x; nxt[id[x]][1]=y;
for (int i=x;i!=y;i=f[i][0])
bt[id[x]][id[y]][a[i]]=1;
bt[id[x]][id[y]][a[y]]=1;
for (int i=2,z;i<=N/B+1;i++)
{
z=nxt[id[x]][i]=nxt[id[y]][i-1];
bt[id[x]][id[z]]|=bt[id[x]][id[y]];
bt[id[x]][id[z]]|=bt[id[y]][id[z]];
}
y=x;
}
for (int i=head[x];~i;i=e[i].next)
{
int v=e[i].to;
if (v!=f[x][0]) dfs2(v,y);
}
}
int main()
{
memset(head,-1,sizeof(head));
scanf("%d%d",&n,&m);
for (int i=1;i<=n;i++)
scanf("%d",&a[i]),b[i]=a[i];
for (int i=1,x,y;i<n;i++)
{
scanf("%d%d",&x,&y);
add(x,y); add(y,x);
}
sort(b+1,b+1+n);
tot=unique(b+1,b+1+n)-b-1;
for (int i=1;i<=n;i++)
a[i]=lower_bound(b+1,b+1+tot,a[i])-b;
tot=0;
dfs1(1,0); dfs2(1,0);
while (m--)
{
int x,y,p;
scanf("%d%d",&x,&y);
x^=last; p=lca(x,y);
s.reset(); s[a[p]]=1;
for (;!id[x] && x!=p;x=f[x][0]) s[a[x]]=1;
for (;!id[y] && y!=p;y=f[y][0]) s[a[y]]=1;
if (id[x])
for (int i=1;;i++)
if (dep[nxt[id[x]][i]]<dep[p])
{
s|=bt[id[x]][id[nxt[id[x]][i-1]]];
x=nxt[id[x]][i-1];
break;
}
if (id[y])
for (int i=1;;i++)
if (dep[nxt[id[y]][i]]<dep[p])
{
s|=bt[id[y]][id[nxt[id[y]][i-1]]];
y=nxt[id[y]][i-1];
break;
}
for (;x!=p;x=f[x][0]) s[a[x]]=1;
for (;y!=p;y=f[y][0]) s[a[y]]=1;
cout<<(last=s.count())<<"\n";
}
return 0;
}