树分块学习笔记

思想

树分块是一种能解决部分操作树上一条链的一种算法。

回忆下序列上的分块,其最精髓的地方在于将序列分成许多段,如果操作的区间包括了某一段,则直接使用整体处理这一段。我们也要使用某种方法使得操作的链也被分成许多块,但像 dfs 序等并不一定能保证整段的大小稳定。

先设定一个阈值 \(S\),我们要求每一段链的长度接近 \(S\)。一种方法是随机选取 \(\frac n S\) 个点,期望每一段的长度是 \(S\),但是太过玄学,便不使用这种方法。我们先处理出每个节点的深度及其祖先,若当前没有遍历的最深节点\(1\sim S\) 级祖先都没有被选取,则将其 \(S\) 级祖先选取。因为深度从大到小遍历,所以每个选取的点至少会覆盖 \(S\) 个节点,所以至多会有 \(\frac n S\) 个节点被选取,称这些被选取的点为关键点。而相邻两个关键点(需满足这两个点的 \(LCA\) 为其中一个点)间的链便相当于序列上分块的整段。

接下来处理出关键点间两两的答案(两点间仍需满足这两个点的 \(LCA\) 为其中一个点),为了防止查询时时间复杂度过大。预处理的时间复杂度为 \(O(\frac {n^2}{S^2}W)\),其中 \(W\) 为合并两段区间答案的复杂度。

剩下按照分块的套路,同时将整块答案与散块答案相加即可。具体而言,就是先找到 \(u,v\)\(LCA\),分别处理 \(u\sim LCA,v\sim LCA\) 即可。(有的题目中需要注意 \(LCA\) 不能被算两次)

实战

P6177 Count on a tree II/【模板】树分块

bitset 来表示有哪些颜色出现过,合并的时间复杂度为 \(O(\frac n \omega)\)

#pragma GCC optimize(3)
#include<iostream>
#include<set>
#include<bitset>
#include<algorithm>
#include<vector>
#include<cmath>
using namespace std;
#define N 40010
#define S 5010
#define NS 550
bitset<N> bt[NS][NS];
bitset<N> ans;
int n,m,s,u,v,a[N],cnt,b[N],cnt2,w[N],ys[N];
int fat[N][21],dep[N],vis[N],top[N];
int len,f[N];
bitset<N> sum[N];
vector<int> g[N];
struct node
{
	int w;
	friend bool operator<(const node a,const node b)
	{
		return dep[a.w]>dep[b.w];
	}
};
multiset<node> st;
void dfs(int u,int fa)
{
	fat[u][0]=fa;
	dep[u]=dep[fa]+1;
	for(int v:g[u])
	if(v!=fa)
	{
		dfs(v,u);
	}
}
void slt()
{
//	int s=sqrt(n);
	s=220;
	while(!st.empty())
	{
		int u=(*st.begin()).w;
		st.erase(st.begin());
		int w=u;
		bool fl=0;
		for(int i=1;i<=s;i++)
		{
			w=fat[w][0];
			if(vis[w])
			{
				fl=1;
				break;
			}
		}
		if(!fl&&w!=0)
		{
			vis[w]=1;
			b[++cnt]=w;
		}
		else if(u==1&&!vis[1])
		{
			vis[1]=1;
			b[++cnt]=1;
		}
	}
}
void dfs2(int u,int tpf)
{
	sum[u].set(a[u]);
	if(vis[u])
	{
		int pos=0;
		for(int i=1;i<=cnt;i++)
		if(b[i]==u)
		{
			pos=i;
			break;
		}
		ys[u]=pos;
		if(u!=1)
		{
			bt[min(pos,w[cnt])][max(pos,w[cnt])]=sum[u];
			for(int i=1;i<=cnt-1;i++)
			{
				bt[min(w[i],pos)][max(w[i],pos)]=(bt[min(w[i],w[cnt])][max(w[i],w[cnt])]|bt[min(w[cnt],pos)][max(w[cnt],pos)]);
			}
		}
		w[++cnt]=pos;
		sum[u].reset();
		sum[u].set(a[u]);
		tpf=u;
	}
	top[u]=tpf;
	for(int v:g[u])
	if(v!=fat[u][0])
	{
		sum[v]=sum[u];
		dfs2(v,tpf);
	}
	if(vis[u])
	 cnt--;
}
int getlca(int x,int y)
{
	if(dep[x]>dep[y])
	 swap(x,y);
	for(int i=20;i>=0;i--)
	if(dep[fat[y][i]]>=dep[x])
	 y=fat[y][i];
	if(x==y)
	{
		return x;
	}
	for(int i=20;i>=0;i--)
	if(fat[x][i]!=fat[y][i])
	{
		x=fat[x][i];
		y=fat[y][i];
	}
	return fat[x][0];
}
int getans(int u,int v)
{
	ans.reset();
	int l=getlca(u,v);
	if(top[u]==top[v])
	{
		if(dep[u]>dep[v])
		 swap(u,v);
		while(v!=l)
		{
			ans.set(a[v]);
			v=fat[v][0];
		}
		while(u!=l)
		{
			ans.set(a[u]);
			u=fat[u][0];
		}
		ans.set(a[l]);
		return ans.count();
	}
	while(dep[u]>dep[top[u]]&&dep[u]>dep[l])
	{
		ans.set(a[u]);
		u=fat[u][0];
	}
	while(dep[v]>dep[top[v]]&&dep[v]>dep[l])
	{
		ans.set(a[v]);
		v=fat[v][0];
	}
	ans.set(a[u]);
	ans.set(a[v]);
	if(dep[u]>dep[l])
	{
		int uu=u;
		while(dep[top[fat[uu][0]]]>dep[l])
		{
			uu=top[fat[uu][0]];
		}
		ans|=bt[min(ys[u],ys[uu])][max(ys[u],ys[uu])];
		u=uu;
		while(dep[u]>dep[l])
		{
			ans.set(a[u]);
			u=fat[u][0];
		}
	}
	if(dep[v]>dep[l])
	{
		int vv=v;
		while(dep[top[fat[vv][0]]]>dep[l])
		{
			vv=top[fat[vv][0]];
		}
		ans|=bt[min(ys[v],ys[vv])][max(ys[v],ys[vv])];
		v=vv;
		while(dep[v]>dep[l])
		{
			ans.set(a[v]);
			v=fat[v][0];
		}
	}
	ans.set(a[l]);
	return ans.count();
}
int main()
{
	ios::sync_with_stdio(false);
	cin.tie(0);
	cout.tie(0);
	cin>>n>>m;
	for(int i=1;i<=n;i++)
	{
		cin>>a[i];
		f[++len]=a[i];
	}
	for(int i=1;i<=n-1;i++)
	{
		cin>>u>>v;
		g[u].push_back(v);
		g[v].push_back(u);
	}
	sort(f+1,f+len+1);
	len=unique(f+1,f+len+1)-f-1;
	for(int i=1;i<=n;i++)
	{
		a[i]=lower_bound(f+1,f+len+1,a[i])-f;
	}
	dfs(1,0);
	st.clear();
	for(int i=1;i<=n;i++)
	{
		st.insert((node){i});
	}
	slt();
	ans.reset();
	dfs2(1,0);
	for(int i=1;i<=20;i++)
	for(int j=1;j<=n;j++)
	 fat[j][i]=fat[fat[j][i-1]][i-1];
	int lasans=0;
	for(int i=1;i<=m;i++)
	{
		cin>>u>>v;
		u^=lasans;
		cout<<(lasans=getans(u,v))<<"\n";
	}
}
posted @ 2023-08-18 20:16  Lyz09  阅读(144)  评论(0编辑  收藏  举报