BZOJ3451. Tyvj1953 Normal
考虑每个点 $i$ 对答案的贡献
当删去一个节点 $j$ 的时候, $i$ 会对 $j$ 产生 $1$ 的贡献当且仅当 $i,j$ 这条链上的所有点中,$j$ 是第一个删除的节点
显然链上每个节点第一个被删除的概率是一样的
所以点对 $i,j$ 的贡献就是 $\frac{1}{dis(i,j)}$,其中 $dis(i,i)=1$
那么答案就相当于 $\sum_{i}\sum_{j}\frac{1}{dis(i,j)}$
发现可以转化为求,对于每个值 $k$ ,$dis=k$ 的点对的数量
显然直接点分治
对于每一个分治节点,统计跨过它的各种长度的路径数量
设 $A[k]$ 表示当前节点 $x$ 的点分子树内所有到 $x$ 的路径长度恰好为 $k$ 的路径数量,设 $B[k]$ 为跨过 $x$ 的两点路径长度恰好为 $k$ 的路径的数量
则有 $B[k]=\sum_{j=0}^{k}A[j]A[k-j]$,是卷积的形式,可以 $FFT$ 优化
这样没有考虑两点在 $x$ 的同一儿子 $v$ 子树内的不合法情况,但是可以直接用同样的方法算出对于 $v$ 的 $B[]$,减一下就行了
这样总复杂度 $O(nlog_{n}^{2})$
只要熟悉点分治和 $FFT$ ,代码不难理解
#include<iostream> #include<cstdio> #include<algorithm> #include<cstring> #include<cmath> using namespace std; typedef long long ll; typedef long double ldb; inline int read() { int x=0,f=1; char ch=getchar(); while(ch<'0'||ch>'9') { if(ch=='-') f=-1; ch=getchar(); } while(ch>='0'&&ch<='9') { x=(x<<1)+(x<<3)+(ch^48); ch=getchar(); } return x*f; } const int N=4e5+7; const ldb pi=acos(-1.0); struct CP { ldb x,y; CP (ldb xx=0,ldb yy=0) { x=xx,y=yy; } inline CP operator + (const CP &tmp) const { return CP(x+tmp.x,y+tmp.y); } inline CP operator - (const CP &tmp) const { return CP(x-tmp.x,y-tmp.y); } inline CP operator * (const CP &tmp) const { return CP(x*tmp.x-y*tmp.y,x*tmp.y+y*tmp.x); } }A[N]; int n,p[N]; void FFT(CP *A,int len,int type) { for(int i=0;i<len;i++) if(i<p[i]) swap(A[i],A[p[i]]); for(int mid=1;mid<len;mid<<=1) { CP wn(cos(pi/mid),type*sin(pi/mid)); for(int R=mid<<1,j=0;j<len;j+=R) { CP w(1,0); for(int k=0;k<mid;k++,w=w*wn) { CP x=A[j+k],y=w*A[j+mid+k]; A[j+k]=x+y; A[j+mid+k]=x-y; } } } } int fir[N],from[N<<1],to[N<<1],cntt; inline void add(int a,int b) { from[++cntt]=fir[a]; fir[a]=cntt; to[cntt]=b; } int sz[N],mx[N],rt,tot; bool vis[N]; void find_rt(int x,int fa) { sz[x]=1; mx[x]=0; for(int i=fir[x];i;i=from[i]) { int &v=to[i]; if(v==fa||vis[v]) continue; find_rt(v,x); sz[x]+=sz[v]; mx[x]=max(mx[x],sz[v]); } mx[x]=max(mx[x],tot-sz[x]); if(mx[x]<mx[rt]) rt=x; } int st[N],Top; void dfs(int x,int fa,int dis) { st[++Top]=dis; for(int i=fir[x];i;i=from[i]) if(to[i]!=fa&&!vis[to[i]]) dfs(to[i],x,dis+1); } ll ans[N]; void calc(int type) { int mx=0,len=1,tot=0; for(int i=1;i<=Top;i++) mx=max(mx,st[i]); while(len<=2*mx) len<<=1,tot++; for(int i=0;i<=len;i++) A[i]=CP(0,0); for(int i=1;i<=Top;i++) A[st[i]].x++; for(int i=0;i<len;i++) p[i]=(p[i>>1]>>1)|((i&1)<<(tot-1)); FFT(A,len,1); for(int i=0;i<=len;i++) A[i]=A[i]*A[i]; FFT(A,len,-1); for(int i=0;i<=mx*2;i++) ans[i]+=1ll*type*ll(A[i].x/len+0.5); } void solve(int x) { vis[x]=1; Top=0; dfs(x,0,0); calc(1); for(int i=fir[x];i;i=from[i]) { int &v=to[i]; if(vis[v]) continue; Top=0; dfs(v,x,1); calc(-1); rt=0; tot=sz[v]; find_rt(v,x); solve(rt); } } ldb Ans=0; int main() { n=read(); int a,b; for(int i=1;i<n;i++) { a=read()+1,b=read()+1; add(a,b),add(b,a); } tot=n; mx[0]=2333333; find_rt(1,0); solve(rt); for(int i=0;i<n;i++) Ans+=(ldb)ans[i]/(i+1); printf("%.4Lf\n",Ans); return 0; }