树分块学习笔记
思想
树分块是一种能解决部分操作树上一条链的一种算法。
回忆下序列上的分块,其最精髓的地方在于将序列分成许多段,如果操作的区间包括了某一段,则直接使用整体处理这一段。我们也要使用某种方法使得操作的链也被分成许多块,但像 dfs 序等并不一定能保证整段的大小稳定。
先设定一个阈值 \(S\),我们要求每一段链的长度接近 \(S\)。一种方法是随机选取 \(\frac n S\) 个点,期望每一段的长度是 \(S\),但是太过玄学,便不使用这种方法。我们先处理出每个节点的深度及其祖先,若当前没有遍历的最深节点的 \(1\sim S\) 级祖先都没有被选取,则将其 \(S\) 级祖先选取。因为深度从大到小遍历,所以每个选取的点至少会覆盖 \(S\) 个节点,所以至多会有 \(\frac n S\) 个节点被选取,称这些被选取的点为关键点。而相邻两个关键点(需满足这两个点的 \(LCA\) 为其中一个点)间的链便相当于序列上分块的整段。
接下来处理出关键点间两两的答案(两点间仍需满足这两个点的 \(LCA\) 为其中一个点),为了防止查询时时间复杂度过大。预处理的时间复杂度为 \(O(\frac {n^2}{S^2}W)\),其中 \(W\) 为合并两段区间答案的复杂度。
剩下按照分块的套路,同时将整块答案与散块答案相加即可。具体而言,就是先找到 \(u,v\) 的 \(LCA\),分别处理 \(u\sim LCA,v\sim LCA\) 即可。(有的题目中需要注意 \(LCA\) 不能被算两次)
实战
P6177 Count on a tree II/【模板】树分块
用 bitset
来表示有哪些颜色出现过,合并的时间复杂度为 \(O(\frac n \omega)\)。
#pragma GCC optimize(3)
#include<iostream>
#include<set>
#include<bitset>
#include<algorithm>
#include<vector>
#include<cmath>
using namespace std;
#define N 40010
#define S 5010
#define NS 550
bitset<N> bt[NS][NS];
bitset<N> ans;
int n,m,s,u,v,a[N],cnt,b[N],cnt2,w[N],ys[N];
int fat[N][21],dep[N],vis[N],top[N];
int len,f[N];
bitset<N> sum[N];
vector<int> g[N];
struct node
{
int w;
friend bool operator<(const node a,const node b)
{
return dep[a.w]>dep[b.w];
}
};
multiset<node> st;
void dfs(int u,int fa)
{
fat[u][0]=fa;
dep[u]=dep[fa]+1;
for(int v:g[u])
if(v!=fa)
{
dfs(v,u);
}
}
void slt()
{
// int s=sqrt(n);
s=220;
while(!st.empty())
{
int u=(*st.begin()).w;
st.erase(st.begin());
int w=u;
bool fl=0;
for(int i=1;i<=s;i++)
{
w=fat[w][0];
if(vis[w])
{
fl=1;
break;
}
}
if(!fl&&w!=0)
{
vis[w]=1;
b[++cnt]=w;
}
else if(u==1&&!vis[1])
{
vis[1]=1;
b[++cnt]=1;
}
}
}
void dfs2(int u,int tpf)
{
sum[u].set(a[u]);
if(vis[u])
{
int pos=0;
for(int i=1;i<=cnt;i++)
if(b[i]==u)
{
pos=i;
break;
}
ys[u]=pos;
if(u!=1)
{
bt[min(pos,w[cnt])][max(pos,w[cnt])]=sum[u];
for(int i=1;i<=cnt-1;i++)
{
bt[min(w[i],pos)][max(w[i],pos)]=(bt[min(w[i],w[cnt])][max(w[i],w[cnt])]|bt[min(w[cnt],pos)][max(w[cnt],pos)]);
}
}
w[++cnt]=pos;
sum[u].reset();
sum[u].set(a[u]);
tpf=u;
}
top[u]=tpf;
for(int v:g[u])
if(v!=fat[u][0])
{
sum[v]=sum[u];
dfs2(v,tpf);
}
if(vis[u])
cnt--;
}
int getlca(int x,int y)
{
if(dep[x]>dep[y])
swap(x,y);
for(int i=20;i>=0;i--)
if(dep[fat[y][i]]>=dep[x])
y=fat[y][i];
if(x==y)
{
return x;
}
for(int i=20;i>=0;i--)
if(fat[x][i]!=fat[y][i])
{
x=fat[x][i];
y=fat[y][i];
}
return fat[x][0];
}
int getans(int u,int v)
{
ans.reset();
int l=getlca(u,v);
if(top[u]==top[v])
{
if(dep[u]>dep[v])
swap(u,v);
while(v!=l)
{
ans.set(a[v]);
v=fat[v][0];
}
while(u!=l)
{
ans.set(a[u]);
u=fat[u][0];
}
ans.set(a[l]);
return ans.count();
}
while(dep[u]>dep[top[u]]&&dep[u]>dep[l])
{
ans.set(a[u]);
u=fat[u][0];
}
while(dep[v]>dep[top[v]]&&dep[v]>dep[l])
{
ans.set(a[v]);
v=fat[v][0];
}
ans.set(a[u]);
ans.set(a[v]);
if(dep[u]>dep[l])
{
int uu=u;
while(dep[top[fat[uu][0]]]>dep[l])
{
uu=top[fat[uu][0]];
}
ans|=bt[min(ys[u],ys[uu])][max(ys[u],ys[uu])];
u=uu;
while(dep[u]>dep[l])
{
ans.set(a[u]);
u=fat[u][0];
}
}
if(dep[v]>dep[l])
{
int vv=v;
while(dep[top[fat[vv][0]]]>dep[l])
{
vv=top[fat[vv][0]];
}
ans|=bt[min(ys[v],ys[vv])][max(ys[v],ys[vv])];
v=vv;
while(dep[v]>dep[l])
{
ans.set(a[v]);
v=fat[v][0];
}
}
ans.set(a[l]);
return ans.count();
}
int main()
{
ios::sync_with_stdio(false);
cin.tie(0);
cout.tie(0);
cin>>n>>m;
for(int i=1;i<=n;i++)
{
cin>>a[i];
f[++len]=a[i];
}
for(int i=1;i<=n-1;i++)
{
cin>>u>>v;
g[u].push_back(v);
g[v].push_back(u);
}
sort(f+1,f+len+1);
len=unique(f+1,f+len+1)-f-1;
for(int i=1;i<=n;i++)
{
a[i]=lower_bound(f+1,f+len+1,a[i])-f;
}
dfs(1,0);
st.clear();
for(int i=1;i<=n;i++)
{
st.insert((node){i});
}
slt();
ans.reset();
dfs2(1,0);
for(int i=1;i<=20;i++)
for(int j=1;j<=n;j++)
fat[j][i]=fat[fat[j][i-1]][i-1];
int lasans=0;
for(int i=1;i<=m;i++)
{
cin>>u>>v;
u^=lasans;
cout<<(lasans=getans(u,v))<<"\n";
}
}