最裸的点分治+fft,调了好久,太菜了。。。。
#include<iostream> #include<cstring> #include<cstdio> #include<cmath> #include<algorithm> using namespace std; typedef long long ll; const int maxn=200010,inf=1e9; const double pi=acos(-1); int f[maxn],t,last[maxn],pre[maxn],other[maxn],siz[maxn],vis[maxn]; int mi,root,rev[maxn],dep,N,n,p[maxn],tot,is[maxn]; ll sum[maxn],c[maxn],cnt[maxn]; void add(int x,int y){++t;pre[t]=last[x];last[x]=t;other[t]=y;} void getroot(int x,int fa,int ac){ f[x]=0; for(int i=last[x];i;i=pre[i]){ int v=other[i]; if(vis[v]||v==fa)continue; getroot(v,x,ac); f[x]=max(f[x],siz[v]); } f[x]=max(siz[ac]-siz[x],f[x]);//注意这里是siz[ac]而不是n; if(f[x]<mi){mi=f[x];root=x;} } void dfs(int x,int fa,int d){ c[d]++; for(int i=last[x];i;i=pre[i]){ int v=other[i]; if(vis[v]||v==fa)continue; dfs(v,x,d+1); } } struct cp{ double r,i; cp operator+(cp&t){cp tp;tp.r=r+t.r;tp.i=i+t.i;return tp;} cp operator-(cp&t){cp tp;tp.r=r-t.r;tp.i=i-t.i;return tp;} cp operator*(cp&t){cp tp;tp.r=r*t.r-i*t.i;tp.i=t.r*i+t.i*r;return tp;} }A[maxn],B[maxn],tmp[maxn],wn,w,x,y; void fft(cp a[],int n,int flag){ for(int i=0;i<n;++i){ rev[i]=rev[i>>1]>>1; if(i&1)rev[i]|=(n>>1); } for(int i=0;i<n;++i)tmp[i]=a[rev[i]]; for(int i=0;i<n;++i)a[i]=tmp[i]; for(int i=2;i<=n;i<<=1){ wn.r=cos(2*pi/i);wn.i=flag*sin(2*pi/i); for(int j=0;j<n;j+=i){ w.r=1;w.i=0; for(int k=j;k<j+i/2;++k){ x=a[k];y=a[k+i/2]*w; a[k]=x+y;a[k+i/2]=x-y; w=w*wn; } } } if(flag==-1)for(int i=0;i<n;++i)a[i].r/=n; } void Siz(int x,int fa){ siz[x]=1; for(int i=last[x];i;i=pre[i]){ int v=other[i]; if(v==fa||vis[v])continue; Siz(v,x); siz[x]+=siz[v]; } } void calc(ll a[],int n,int flag){ for(int i=0;i<n;++i)A[i].r=a[i],A[i].i=0; for(int i=0;i<n;++i)B[i].r=a[i],B[i].i=0; fft(A,n,1); fft(B,n,1); for(int i=0;i<n;++i)A[i]=A[i]*B[i]; fft(A,n,-1); for(int i=0;i<n;++i)sum[i]+=flag*(ll)(A[i].r+0.3); } void solve(int x){ mi=1e9; ll res=0; Siz(x,0); for(N=1;N<=siz[x];N<<=1); for(int i=0;i<N;++i)cnt[i]=0; cnt[0]=1; for(int i=last[x];i;i=pre[i]){ int v=other[i]; if(vis[v])continue; for(N=1;N<=2*siz[v];N<<=1); for(int j=0;j<N;++j)c[j]=0; dfs(v,x,1); calc(c,N,-1); for(int j=0;j<N;++j)cnt[j]+=c[j]; } for(N=1;N<=siz[x];N<<=1); calc(cnt,N,1); /*for(int i=0;i<n;++i)cout<<A[i].r<<' '; cout<<endl;*/ sum[0]=0; } void divont(int x){ mi=1e9; Siz(x,0); getroot(x,0,x); int u=root; //cout<<u<<endl; solve(u); vis[u]=1; for(int i=last[u];i;i=pre[i]){ int v=other[i]; if(!vis[v])divont(v); } } int main(){ cin>>n; int x,y; for(int i=1;i<n;++i){ scanf("%d%d",&x,&y); add(x,y);add(y,x); } divont(1); for(int i=2;i<=50010;++i){ if(!is[i]){p[++tot]=i;} for(int j=1;j<=tot&&i*p[j]<=50010;++j){ is[i*p[j]]=1; if(i%p[j]==0)break; } } double mu=(double)n*(n-1)/2,res=0; for(int i=1;i<=tot&&p[i]<=n;++i){ res+=sum[p[i]]; } res/=2; printf("%.7lf",double(res)/double(mu)); return 0; }