【Codechef】—Prime Distance On Tree(点分治+FFT)
水题
看到询问路径就想到点分治
令表示当前中心,深度为的点的个数
表示长度为的路径的个数
则
也就是自己和自己卷积
计算出后统计一下所有质数就可以了
#include<bits/stdc++.h>
#define int long long
using namespace std;
inline int read(){
int ans=0;
char ch=getchar();
while(!isdigit(ch))ch=getchar();
while(isdigit(ch))ans=(ans<<3)+(ans<<1)+(ch^48),ch=getchar();
return ans;
}
typedef long long ll;
const double pi=acos(-1.0);
const int N=2e5+5;
int n,pri[N],tot=0,msiz[N],siz[N],vis[N],rt,all,tmp[N],lim=1,tim=0,pos[N],maxn;
bool isp[N];
vector<int>e[N>>2];
ll ans=0;
inline void init(){
isp[1]=1;
for(int i=2;i<=n;++i){
if(!isp[i])pri[++tot]=i;
for(int j=1;j<=tot&&i*pri[j]<=n;++j){
isp[i*pri[j]]=1;
if(i%pri[j]==0)break;
}
}
}
struct plx{
double x,y;
plx(double _x=0,double _y=0):x(_x),y(_y){}
friend inline plx operator +(const plx &a,const plx &b){
return plx(a.x+b.x,a.y+b.y);
}
friend inline plx operator -(const plx &a,const plx &b){
return plx(a.x-b.x,a.y-b.y);
}
friend inline plx operator *(const plx &a,const plx &b){
return plx(a.x*b.x-a.y*b.y,a.x*b.y+a.y*b.x);
}
}cnt[N];
inline void fft(plx f[],int kd){
for(int i=0;i<lim;i++)if(i<pos[i])swap(f[i],f[pos[i]]);
for(int mid=1;mid<lim;mid<<=1){
plx now=plx(cos(pi/mid),kd*sin(pi/mid));
for(int i=0;i<lim;i+=(mid<<1)){
plx w=plx(1,0);
for(int j=0;j<mid;j++,w=w*now){
plx a0=f[i+j],a1=w*f[i+j+mid];
f[i+j]=a0+a1,f[i+j+mid]=a0-a1;
}
}
}
if(kd==-1)for(int i=0;i<lim;i++)f[i].x/=lim;
}
void getroot(int p,int fa){
msiz[p]=siz[p]=1;
for(int i=0;i<e[p].size();++i){
int v=e[p][i];
if(v==fa||vis[v])continue;
getroot(v,p),siz[p]+=siz[v],msiz[p]=max(msiz[p],siz[v]);
}
msiz[p]=max(msiz[p],all-siz[p]);
if(msiz[p]<msiz[rt])rt=p;
}
void getdis(int p,int fa,int delt){
cnt[delt].x+=1;
maxn=max(maxn,delt);
for(int i=0;i<e[p].size();++i){
int v=e[p][i];
if(vis[v]||v==fa)continue;
getdis(v,p,delt+1);
}
}
inline void calc(int p,int delt,int type){
maxn=0,getdis(p,0,delt),lim=1,tim=0;
while(lim<=maxn*2)++tim,lim<<=1;
for(int i=0;i<lim;++i)pos[i]=(pos[i>>1]>>1)|((i&1)<<(tim-1));
int sum=(int)-cnt[1].x;
fft(cnt,1);
for(int i=0;i<lim;++i)cnt[i]=cnt[i]*cnt[i];
fft(cnt,-1);
for(int i=1;i<=tot;++i){
if(pri[i]>lim)break;
sum+=(int)(cnt[pri[i]].x+0.5);
}
ans+=(ll)sum/2*type;
for(int i=0;i<lim;++i)cnt[i].x=cnt[i].y=0;
}
void solve(int p){
calc(p,0,1),vis[p]=1;
for(int i=0;i<e[p].size();++i){
int v=e[p][i];
if(vis[v])continue;
all=siz[v],rt=0,calc(v,1,-1),getroot(v,0),solve(rt);
}
}
signed main(){
n=read(),init();
for(int i=1,u,v;i<n;++i)u=read(),v=read(),e[u].push_back(v),e[v].push_back(u);
all=msiz[rt=0]=n,getroot(1,0),solve(rt);
printf("%.8lf",(double)ans*2.0/(double)((double)n*(double)(n-1)));
return 0;
}