BZOJ3451 Tyvj1953 Normal 点分治 多项式 FFT
原文链接https://www.cnblogs.com/zhouzhendong/p/BZOJ3451.html
题目传送门 - BZOJ3451
题意
给定一棵有 $n$ 个节点的树,在树上随机点分治,问消耗时间的期望。
计算点分治耗时由如下函数给出:
Time = 0 Solve( T ){ Time += |T| if ( |T| = 1 ) then return ; x = 一个随机节点 in T for y in {与 x 直接连边的节点 in T} do Solve( SubTree y ) }
$n\leq 3\times 10^4$
题解
考虑最终在点分树上,如果 $y$ 是 $x$ 的祖先,那么就会产生一点贡献。我们把它记为:数对 $(x,y)$ 产生了一点贡献。
回到原树,如果数对 $(x,y)$ 能产生贡献,那么 $y$ 一定是 $x$ 到 $y$ 路径上面最先被选中的。因为, $x$ 显然不能先选,而如果选择了 $x$ 到 $y$ 之间的节点,那么 $x$ 和 $y$ 在点分树中就被分到了两个子树中,不能产生贡献。
由于是等概率随机选点,所以对于任何两个点 $(x,y)$ ,$y$ 最先被选中的概率显然是 $\cfrac{1}{dis(x,y)}$ 。
所以,答案被转化为:
$$\sum_{i=1}^{n}\sum_{j=1}^{n}\cfrac{1}{dis(x,y)}$$
于是我们就可以 点分治 + FFT 了。
具体地:
我们需要得到每一个长度的路径条数。
对于每一个点分中心,我们先分治处理所有子树。
然后,dfs 遍历所有子树,对于每一个子树得到子树内的深度值的情况。注意,用 vector 存。
然后,任意两个子树的信息就可以合并了。
但是暴力合并显然是不对的,$n^2$ 复杂度肯定超时。(下面把子树深度数组当作多项式)
考虑将子树按照最深深度排序,从深度小到深度大,通过多项式前缀和一下,FFT 优化多项式乘法即可。
再具体??看代码吧……
时间复杂度 $O(n\log^2 n)$ 。
代码
#include <bits/stdc++.h> using namespace std; int read(){ int x=0; char ch=getchar(); while (!isdigit(ch)) ch=getchar(); while (isdigit(ch)) x=(x<<1)+(x<<3)+ch-48,ch=getchar(); return x; } const int N=30005; struct Gragh{ static const int M=N*2; int cnt,y[M],nxt[M],fst[N]; void clear(){ cnt=0; memset(fst,0,sizeof fst); } void add(int a,int b){ y[++cnt]=b,nxt[cnt]=fst[a],fst[a]=cnt; } }g; int Pow(int x,int y,int mod){ int ans=1; for (;y;y>>=1,x=1LL*x*x%mod) if (y&1) ans=1LL*ans*x%mod; return ans; } namespace FaFaTa{ static const int N=1<<15,mod=998244353; int n,d; int w[N],R[N],A[N],B[N]; vector <int> res; void FFT(int a[],int n){ for (int i=0;i<n;i++) if (R[i]<i) swap(a[i],a[R[i]]); for (int d=1,t=n>>1;d<n;d<<=1,t>>=1) for (int i=0;i<n;i+=(d<<1)) for (int j=0;j<d;j++){ int tmp=1LL*w[t*j]*a[i+j+d]%mod; a[i+j+d]=(a[i+j]-tmp+mod)%mod; a[i+j]=(a[i+j]+tmp)%mod; } } void FFT_Mul(vector <int> &a,vector <int> &b){ int la=a.size(),lb=b.size(); for (n=1,d=0;n<la+lb-1;n<<=1,d++); for (int i=0;i<n;i++) R[i]=(R[i>>1]>>1)|((i&1)<<(d-1)); w[0]=1,w[1]=Pow(3,(mod-1)/n,mod); for (int i=2;i<n;i++) w[i]=1LL*w[i-1]*w[1]%mod; for (int i=0;i<n;i++) A[i]=i<la?a[i]:0; for (int i=0;i<n;i++) B[i]=i<lb?b[i]:0; FFT(A,n),FFT(B,n); for (int i=0;i<n;i++) A[i]=1LL*A[i]*B[i]%mod; w[1]=Pow(w[1],mod-2,mod); for (int i=2;i<n;i++) w[i]=1LL*w[i-1]*w[1]%mod; FFT(A,n); int inv=Pow(n,mod-2,mod); for (int i=0;i<n;i++) A[i]=1LL*A[i]*inv%mod; res.clear(); for (int i=0;i<n;i++) res.push_back(A[i]); } } int n,Time; int ans[N],vis[N],size[N],Max[N],root; int id[N],id_cnt,dcnt[N]; vector <int> depth[N]; void Get_root(int x,int pre){ id[++id_cnt]=x; size[x]=1,Max[x]=0; for (int i=g.fst[x];i;i=g.nxt[i]) if (g.y[i]!=pre&&!vis[g.y[i]]){ Get_root(g.y[i],x); size[x]+=size[g.y[i]]; Max[x]=max(Max[x],size[g.y[i]]); } } void Get_depth(int rt,int y,int x,int pre,int d){ if (depth[y].size()<=0.1+d) depth[y].push_back(0); depth[y][d]++; for (int i=g.fst[x];i;i=g.nxt[i]) if (g.y[i]!=pre&&vis[rt]<vis[g.y[i]]) Get_depth(rt,y,g.y[i],x,d+1); } bool cmp(int a,int b){ return depth[a].size()<depth[b].size(); } void solve(int root){ id_cnt=0; Get_root(root,0); int n=size[root],x=root; for (int i=1;i<=id_cnt;i++){ int y=id[i]; Max[y]=max(Max[y],n-size[y]); if (Max[y]<Max[x]) x=y; } vis[x]=++Time; for (int i=g.fst[x];i;i=g.nxt[i]) if (!vis[g.y[i]]) solve(g.y[i]); id_cnt=0; for (int i=g.fst[x];i;i=g.nxt[i]) if (vis[g.y[i]]>vis[x]){ int y=id[++id_cnt]=g.y[i]; depth[y].clear(); depth[y].push_back(0); Get_depth(x,y,y,x,1); } sort(id+1,id+id_cnt+1,cmp); depth[id[0]=0].clear(); depth[0].push_back(1); for (int i=1;i<=id_cnt;i++){ FaFaTa :: FFT_Mul(depth[id[i-1]],depth[id[i]]); for (int j=0;j<FaFaTa :: n;j++) ans[j]+=FaFaTa :: res[j]; for (int j=0;j<depth[id[i-1]].size();j++) depth[id[i]][j]+=depth[id[i-1]][j]; } } int main(){ n=read(); g.clear(); for (int i=1,a,b;i<n;i++){ a=read()+1,b=read()+1; g.add(a,b); g.add(b,a); } memset(ans,0,sizeof ans); memset(vis,0,sizeof vis); Time=0; solve(1); long double tot=n; for (int i=1;i<n;i++) tot+=((long double)ans[i])*2/(i+1); printf("%.4Lf",tot); return 0; }