【Codechef】—Prime Distance On Tree(点分治+FFT)

传送门

fftfft水题

看到询问路径就想到点分治
cnt[i]cnt[i]表示当前中心,深度为ii的点的个数
ans[i]ans[i]表示长度为ii的路径的个数

ans[i]=j=0icnt[j]cnt[ij]ans[i]=\sum_{j=0}^{i}cnt[j]*cnt[i-j]
也就是cntcnt自己和自己卷积

计算出ansans后统计一下所有质数就可以了

#include<bits/stdc++.h>
#define int long long
using namespace std;
inline int read(){
	int ans=0;
	char ch=getchar();
	while(!isdigit(ch))ch=getchar();
	while(isdigit(ch))ans=(ans<<3)+(ans<<1)+(ch^48),ch=getchar();
	return ans;
}
typedef long long ll;
const double pi=acos(-1.0);
const int N=2e5+5;
int n,pri[N],tot=0,msiz[N],siz[N],vis[N],rt,all,tmp[N],lim=1,tim=0,pos[N],maxn;
bool isp[N];
vector<int>e[N>>2];
ll ans=0;
inline void init(){
	isp[1]=1;
	for(int i=2;i<=n;++i){
		if(!isp[i])pri[++tot]=i;
		for(int j=1;j<=tot&&i*pri[j]<=n;++j){
			isp[i*pri[j]]=1;
			if(i%pri[j]==0)break;
		}
	}
}
struct plx{
	double x,y;
	plx(double _x=0,double _y=0):x(_x),y(_y){}
	friend inline plx operator +(const plx &a,const plx &b){
		return plx(a.x+b.x,a.y+b.y);
	}
	friend inline plx operator -(const plx &a,const plx &b){
		return plx(a.x-b.x,a.y-b.y);
	}
	friend inline plx operator *(const plx &a,const plx &b){
		return plx(a.x*b.x-a.y*b.y,a.x*b.y+a.y*b.x);
	}
}cnt[N];
inline void fft(plx f[],int kd){
	for(int i=0;i<lim;i++)if(i<pos[i])swap(f[i],f[pos[i]]);
	for(int mid=1;mid<lim;mid<<=1){
		plx now=plx(cos(pi/mid),kd*sin(pi/mid));
		for(int i=0;i<lim;i+=(mid<<1)){
			plx w=plx(1,0);
			for(int j=0;j<mid;j++,w=w*now){
				plx a0=f[i+j],a1=w*f[i+j+mid];
				f[i+j]=a0+a1,f[i+j+mid]=a0-a1;
			}
		}
	}
	if(kd==-1)for(int i=0;i<lim;i++)f[i].x/=lim;
}
void getroot(int p,int fa){
	msiz[p]=siz[p]=1;
	for(int i=0;i<e[p].size();++i){
		int v=e[p][i];
		if(v==fa||vis[v])continue;
		getroot(v,p),siz[p]+=siz[v],msiz[p]=max(msiz[p],siz[v]);
	}
	msiz[p]=max(msiz[p],all-siz[p]);
	if(msiz[p]<msiz[rt])rt=p;
}
void getdis(int p,int fa,int delt){
	cnt[delt].x+=1;
	maxn=max(maxn,delt);
	for(int i=0;i<e[p].size();++i){
		int v=e[p][i];
		if(vis[v]||v==fa)continue;
		getdis(v,p,delt+1);
	}
}
inline void calc(int p,int delt,int type){
	maxn=0,getdis(p,0,delt),lim=1,tim=0;
	while(lim<=maxn*2)++tim,lim<<=1;
	for(int i=0;i<lim;++i)pos[i]=(pos[i>>1]>>1)|((i&1)<<(tim-1));
	int sum=(int)-cnt[1].x;
	fft(cnt,1);
	for(int i=0;i<lim;++i)cnt[i]=cnt[i]*cnt[i];
	fft(cnt,-1);
	for(int i=1;i<=tot;++i){
		if(pri[i]>lim)break;
		sum+=(int)(cnt[pri[i]].x+0.5);
	}
	ans+=(ll)sum/2*type;
	for(int i=0;i<lim;++i)cnt[i].x=cnt[i].y=0;
}
void solve(int p){
	calc(p,0,1),vis[p]=1;
	for(int i=0;i<e[p].size();++i){
		int v=e[p][i];
		if(vis[v])continue;
		all=siz[v],rt=0,calc(v,1,-1),getroot(v,0),solve(rt);
	}
}
signed main(){
	n=read(),init();
	for(int i=1,u,v;i<n;++i)u=read(),v=read(),e[u].push_back(v),e[v].push_back(u);
	all=msiz[rt=0]=n,getroot(1,0),solve(rt);
	printf("%.8lf",(double)ans*2.0/(double)((double)n*(double)(n-1)));
	return 0;
}
posted @ 2019-03-14 16:07  Stargazer_cykoi  阅读(120)  评论(0编辑  收藏  举报