Luogu6177 Count on a tree II/【模板】树分块
https://www.luogu.com.cn/problem/P6177
树分块
思路就是每隔\(S\)个点,取一个关键点,记录每两个关键点之间的信息,用\(bitset\)维护
然后每次询问一条链时,拆成以\(lca\)为上端节点的两条链
对于每条链,关键点之间的信息直接取就行,剩下的单独取(\(bitset\) \(or\) 运算)
如何取关键点?我们从深度最大的非关键点开始枚举,若其\(1-S\)级祖先中没有关键点,那么钦定其\(S\)级祖先为关键点
我用了树剖来计算\(1-S\)级祖先中是否有关键点
然后,就\(T\)了。。。
首先,调整块大小,在不\(MLE\)的情况下尽量接近\(\sqrt n\)
然后我在取关键点时,对于同一深度的点,取关键点顺序胡乱\(random\_shuffle\)一下(不\(rand\)就\(T\)了,\(rand\)了也需要看脸,当然,没有开O2)
最终\(S\)取了\(300\),不开\(O2\)最大点\(1.93s\)卡了过去(开\(O2\)最大点\(1.18s\))
\(Code:\)
#include<iostream>
#include<cstdio>
#include<algorithm>
#include<ctime>
#include<bitset>
#define N 40005
#define S 300
#define ls(x) (x << 1)
#define rs(x) ((x << 1) | 1)
#define lc ls(p),l,mid
#define rc rs(p),mid+1,r
#define bt bitset<N>
using namespace std;
int n,m,x,y,tot,mxdep,q[N],a[N],g[N],dep[N],f[N][22],col[N],w[N],fr[N],d[N << 1],nxt[N << 1];
int cnt,lst,rk[N],ht[N],z[N],st[N],ss[N],dfn[N],id[N],sz[N],son[N],T[N],fg[N];
bool tr[N << 2];
bool ke[N];
bt k[135],ans;
bt s[135][135];
inline int read()
{
int s=0;
char c=getchar();
while (c<'0' || c>'9')
c=getchar();
while ('0'<=c && c<='9')
{
s=s*10+c-'0';
c=getchar();
}
return s;
}
inline void add(int x,int y)
{
tot++;
d[tot]=y;
nxt[tot]=fr[x];
fr[x]=tot;
}
void dfs(int u)
{
int mx=-1;
sz[u]=1;
mxdep=max(mxdep,dep[u]);
g[dep[u]]++;
for (int i=fr[u];i;i=nxt[i])
{
int v=d[i];
if (v==f[u][0])
continue;
f[v][0]=u;
dep[v]=dep[u]+1;
dfs(v);
sz[u]+=sz[v];
if (sz[v]>mx)
{
mx=sz[v];
son[u]=v;
}
}
}
void dfs2(int u,int tp)
{
cnt++;
dfn[cnt]=u;
id[u]=cnt;
T[u]=tp;
if (!son[u])
return;
dfs2(son[u],tp);
for (int i=fr[u];i;i=nxt[i])
{
int v=d[i];
if (v==f[u][0] || v==son[u])
continue;
dfs2(v,v);
}
}
inline int lca(int x,int y)
{
if (dep[x]<dep[y])
swap(x,y);
for (int i=20;i>=0;i--)
if (dep[f[x][i]]>=dep[y])
x=f[x][i];
if (x==y)
return x;
for (int i=20;i>=0;i--)
if (f[x][i]!=f[y][i])
{
x=f[x][i];
y=f[y][i];
}
return f[x][0];
}
void update(int p)
{
if (!p)
return;
tr[p]=tr[ls(p)] | tr[rs(p)];
update(p >> 1);
}
void build(int p,int l,int r)
{
if (l==r)
{
ss[l]=p;
return;
}
int mid=(l+r) >> 1;
build(lc);
build(rc);
tr[p]=tr[ls(p)] | tr[rs(p)];
}
bool calc(int p,int l,int r,int x,int y)
{
if (l==x && r==y)
return tr[p];
int mid=(l+r) >> 1;
if (y<=mid)
return calc(lc,x,y); else
if (x>mid)
return calc(rc,x,y); else
{
if (calc(lc,x,mid))
return true;
if (calc(rc,mid+1,y))
return true;
return false;
}
}
bool check(int x,int y)
{
while (T[x]!=T[y])
{
if (calc(1,1,n,id[T[x]],id[x]))
return true;
x=f[T[x]][0];
}
if (calc(1,1,n,id[y],id[x]))
return true;
return false;
}
inline void fl(int x,int y)
{
while (!ke[x])
{
ans.set(col[x]);
if (x==y)
return;
x=f[x][0];
}
int rx=x,px=-1;
while (dep[ht[x]]>=dep[y])
px=x,x=ht[x];
if (~px)
ans|=s[rk[rx]][rk[px]];
ans.set(col[x]);
while (x!=y)
{
x=f[x][0];
ans.set(col[x]);
}
}
void write(int x)
{
if (x>9)
write(x/10);
putchar(x%10+'0');
}
int main()
{
srand(time(NULL));
n=read(),m=read();
for (int i=1;i<=n;i++)
col[i]=read(),w[i]=col[i];
sort(w+1,w+n+1);
int cc=unique(w+1,w+n+1)-w-1;
for (int i=1;i<=n;i++)
col[i]=lower_bound(w+1,w+cc+1,col[i])-w;
for (int i=1;i<n;i++)
{
x=read(),y=read();
add(x,y),add(y,x);
}
dep[1]=1;
dfs(1);
dfs2(1,1);
for (int j=1;j<=20;j++)
for (int i=1;i<=n;i++)
f[i][j]=f[f[i][j-1]][j-1];
for (int i=1;i<=mxdep;i++)
g[i]+=g[i-1],fg[i]=g[i];
for (int i=1;i<=n;i++)
a[g[dep[i]]--]=i;
for (int i=1;i<=mxdep;i++)
if (fg[i]>fg[i-1])
random_shuffle(a+fg[i-1]+2,a+fg[i]+1);
int p=S;
for (int j=20;j>=0;j--)
if (p>=(1 << j))
{
p-=(1 << j);
q[++q[0]]=j;
}
build(1,1,n);
for (int i=n;i;i--)
{
int u=a[i],v=u;
if (ke[u])
continue;
if (dep[u]>S)
{
for (int j=1;j<=q[0];j++)
v=f[v][q[j]];
st[u]=v;
if (!check(u,v))
{
z[++z[0]]=v;
rk[v]=z[0];
ke[v]=true;
tr[ss[id[v]]]=true;
update(ss[id[v]] >> 1);
}
}
}
for (int i=1;i<=z[0];i++)
{
int u=z[i];
k[i].set(col[u]);
u=f[u][0];
while (!ke[u])
{
k[i].set(col[u]);
u=f[u][0];
if (!u)
break;
}
ht[z[i]]=u;
}
for (int i=1;i<=z[0];i++)
{
s[i][i]=k[i];
int u=z[i];
while (ht[u])
{
s[i][rk[ht[u]]]=s[i][rk[u]] | k[rk[ht[u]]];
u=ht[u];
}
}
for (int i=1;i<=m;i++)
{
if (i!=1)
ans.reset();
x=read(),y=read();
x^=lst;
int kz=lca(x,y);
fl(x,kz),fl(y,kz);
lst=ans.count();
write(lst),putchar('\n');
}
return 0;
}