UOJ#33. 【UR #2】树上GCD 点分治 莫比乌斯反演

原文链接https://www.cnblogs.com/zhouzhendong/p/UOJ33.html

题解

  首先我们把问题转化成处理一个数组 ans ,其中 ans[i] 表示 d(u,a) 和 d(v,a) 同时为 i 的倍数的 (u,v) 个数。(最后求答案的时候只要莫比乌斯反演回来就好了。)

  注意一下我的代码中对于 (u,v) 有祖先关系的是分开考虑的。

  先点分治。

  对于一个点分中心 x ,我们把答案分两部分考虑。

  1. 在子树 x 中满足 LCA(u,v) = x 的 (u,v) 对于答案的贡献。

  2. u,v 其中一个点在子树 x 中,另一个不在。

  第一部分非常好求,不加赘述。

  第二部分,我们考虑定义一个阀值 S ,我们预处理出 Smod[i][j] 表示 子树 x 中,到 x 的距离 mod i = j 的点的个数。这样,我们就可以 O(1) 得到 在子树 x 中,到达 x 的某一个祖先的距离为 i 的倍数的点的个数 。这样,我们就可以在 $O(nS)$ 的复杂度内求出对于 $ans[i](i\leq S)$ 的贡献。 对于 i>S 的,我们可以直接暴力计算 在子树 x 中,到达 x 的某一个祖先的距离为 i 的倍数的点的个数 ,复杂度为 $O(n^2/S)$ 。取 $S = O(\sqrt{n})$ 最优。故处理一个点分中心的复杂度为 $O(n\sqrt{n})$ (假设当前连通块大小为 n)。

  所以总的时间复杂度为 $O(n\sqrt{n})$ 。  

代码

#include <bits/stdc++.h>
using namespace std;
typedef long long LL;
LL read(){
	LL x=0,f=0;
	char ch=getchar();
	while (!isdigit(ch))
		f|=ch=='-',ch=getchar();
	while (isdigit(ch))
		x=(x<<1)+(x<<3)+(ch^48),ch=getchar();
	return f?-x:x;
}
const int N=200005,M=500;
int n;
vector <int> e[N];
LL ans[N],ans2[N];
int depth[N],fa[N];
void dfs(int x,int pre,int d){
	fa[x]=pre,depth[x]=d;
	for (auto y : e[x])
		if (y!=pre)
			dfs(y,x,d+1);
}
int vis[N],size[N],Size;
int Maxsize[N],rt;
void get_root(int x,int pre){
	size[x]=1,Maxsize[x]=0;
	for (auto y : e[x])
		if (y!=pre&&!vis[y]){
			get_root(y,x);
			size[x]+=size[y];
			Maxsize[x]=max(Maxsize[x],size[y]);
		}
	Maxsize[x]=max(Maxsize[x],Size-size[x]);
	if (!rt||Maxsize[rt]>Maxsize[x])
		rt=x;
}
vector <int> d[N];
void get_size(int x,int pre){
	size[x]=1;
	for (auto y : e[x])
		if (y!=pre&&!vis[y])
			get_size(y,x),size[x]+=size[y];
}
void getd(int x,int pre,int d,vector <int> &v){
	while (d>=(int)v.size())
		v.push_back(0);
	v[d]++;
	for (auto y : e[x])
		if (y!=pre&&!vis[y])
			getd(y,x,d+1,v);
}
LL S[N];
LL Smod[M][M];
void solve(int x){
	rt=0;
	get_root(x,0);
	assert(rt!=0);
	vis[x=rt]=1;
	for (int i=0;i<=Size;i++)
		S[i]=0;
	int Mx=0;
	for (auto y : e[x])
		if (!vis[y]){
			get_size(y,0);
			if (depth[y]<depth[x])
				continue;
			d[y].clear();
			getd(y,0,1,d[y]);
			int t=d[y].size()-1;
			for (int i=1;i<=t;i++){
				for (int j=i<<1;j<=t;j+=i)
					d[y][i]+=d[y][j];
				ans[i]+=(LL)d[y][i]*S[i];
				S[i]+=d[y][i];
			}
			Mx=max(Mx,t);
			d[y].clear();
		}
	for (int i=Mx;i>=1;i--)
		for (int j=i<<1;j<=Mx;j+=i)
			S[i]-=S[j];
	S[0]++;
	int base=(int)(0.4*sqrt(Mx)+0.5);
	for (int i=1;i<=base;i++){
		for (int j=0;j<i;j++)
			Smod[i][j]=0;
		for (int j=0;j<=Mx;j++)
			Smod[i][j%i]+=S[j];
	}
	for (int f=fa[x],pre=x;f&&!vis[f];pre=f,f=fa[f]){
		d[f].clear();
		for (auto y : e[f])
			if (!vis[y]&&y!=pre&&y!=fa[f])
				getd(y,f,1,d[f]);
		int t=d[f].size()-1;
		for (int i=1;i<=t;i++){
			for (int j=i<<1;j<=t;j+=i)
				d[f][i]+=d[f][j];
			int tmp=(i-(depth[x]-depth[f])%i)%i;
			if (i<=base)
				ans[i]+=(LL)d[f][i]*Smod[i][tmp];
			else
				for (int j=tmp;j<=Mx;j+=i)	
					ans[i]+=(LL)d[f][i]*S[j];
		}
		d[f].clear();
	}
	for (auto y : e[x])
		if (!vis[y])
			Size=size[y],solve(y);
}
int main(){
	n=read();
	for (int i=2;i<=n;i++){
		int x=read();
		e[i].push_back(x);
		e[x].push_back(i);
	}
	dfs(1,0,0);
	Size=n;
	solve(1);
	for (int i=n;i>=1;i--)
		for (int j=i<<1;j<=n;j+=i)
			ans[i]-=ans[j];
	for (int i=1;i<=n;i++)
		ans2[depth[i]]++;
	for (int i=n;i>=1;i--)
		ans2[i]+=ans2[i+1];
	for (int i=1;i<n;i++)
		printf("%lld\n",ans[i]+ans2[i]);
	return 0;
}

  

posted @ 2018-12-19 16:19  zzd233  阅读(375)  评论(0编辑  收藏  举报