【BZOJ3451】【Tyvj1953】—Normal(点分治+NTT)
Description
某天WJMZBMR学习了一个神奇的算法:树的点分治!
这个算法的核心是这样的:
消耗时间=0
Solve(树 a)
消耗时间 += a 的 大小
如果 a 中 只有 1 个点
退出
否则在a中选一个点x,在a中删除点x
那么a变成了几个小一点的树,对每个小树递归调用Solve
我们注意到的这个算法的时间复杂度跟选择的点x是密切相关的。
如果x是树的重心,那么时间复杂度就是O(nlogn)
但是由于WJMZBMR比较傻逼,他决定随机在a中选择一个点作为x!
Sevenkplus告诉他这样做的最坏复杂度是O(n^2)
但是WJMZBMR就是不信>_<。。。
于是Sevenkplus花了几分钟写了一个程序证明了这一点。。。你也试试看吧_
现在给你一颗树,你能告诉WJMZBMR他的傻逼算法需要的期望消耗时间吗?(消耗时间按在Solve里面的那个为标准)
Input
第一行一个整数n,表示树的大小
接下来n-1行每行两个数a,b,表示a和b之间有一条边
注意点是从0开始标号的
Output
一行一个浮点数表示答案
四舍五入到小数点后4位
如果害怕精度跪建议用long double或者extended
Sample Input
3
0 1
1 2
Sample Output
5.6667
HINT
n<=30000
考虑计算一个点对的贡献
那显然只有为分治中心的时候对它才有贡献
考虑如果分治中心不在路径~内那显然不会产生任何影响
而如果一次中心把和隔开了那就没有贡献了
那也就是说对的时候是分治中心是,而不是~路径上的其他点
也就是说概率是
那答案就是
而也就是可以求出每个距离出现了多少次,除以距离就可以了
而这个是可以通过点分治+在的时间求出来
复杂度
#include<bits/stdc++.h>
using namespace std;
const int RLEN=1<<20|1;
#define ll long long
inline char gc(){
static char ibuf[RLEN],*ib,*ob;
(ib==ob)&&(ob=(ib=ibuf)+fread(ibuf,1,RLEN,stdin));
return (ib==ob)?EOF:*ib++;
}
#define gc getchar
inline int read(){
char ch=gc();
int res=0,f=1;
while(!isdigit(ch))f^=ch=='-',ch=gc();
while(isdigit(ch))res=(res+(res<<2)<<1)+(ch^48),ch=gc();
return f?res:-res;
}
const int mod=998244353,g=3;
const int M=120005,N=30005;
inline int add(int a,int b){
return a+b>=mod?a+b-mod:a+b;
}
inline int dec(int a,int b){
return a>=b?a-b:a-b+mod;
}
inline int mul(int a,int b){
return 1ll*a*b>=mod?1ll*a*b%mod:a*b;
}
inline int ksm(int a,int b,int res=1){
for(;b;b>>=1,a=mul(a,a))(b&1)?(res=mul(res,a)):0;return res;
}
int rev[M],A[M],lim,tim;
inline void ntt(int *f,int kd){
for(int i=0;i<lim;i++)if(i<rev[i])swap(f[i],f[rev[i]]);
int bas=(kd==1)?g:((mod+1)/3);
for(int mid=1;mid<lim;mid<<=1){
int now=ksm(bas,(mod-1)/(mid<<1));
for(int i=0;i<lim;i+=(mid<<1)){
int w=1;
for(int j=0;j<mid;j++,w=mul(w,now)){
int a0=f[i+j],a1=mul(w,f[i+j+mid]);
f[i+j]=add(a0,a1),f[i+j+mid]=dec(a0,a1);
}
}
}
if(kd==-1)for(int i=0,inv=ksm(lim,mod-2);i<lim;i++)f[i]=mul(f[i],inv);
}
int dep[N],val[N],ans[N],maxn,rt,mx,siz[N],son[N],adj[N],nxt[N<<1],to[N<<1],vis[N],cnt,tot;
inline void addedge(int u,int v){
nxt[++cnt]=adj[u],adj[u]=cnt,to[cnt]=v;
}
void getrt(int u,int fa){
siz[u]=1,son[u]=0;
for(int e=adj[u];e;e=nxt[e]){
int v=to[e];
if(vis[v]||v==fa)continue;
getrt(v,u),siz[u]+=siz[v];
if(siz[v]>son[u])son[u]=siz[v];
}
son[u]=max(son[u],maxn-siz[u]);
if(son[u]<son[rt])rt=u;
}
void getdep(int u,int fa){
val[++tot]=dep[u],mx=max(mx,dep[u]);
for(int e=adj[u];e;e=nxt[e]){
int v=to[e];
if(vis[v]||v==fa)continue;
dep[v]=dep[u]+1;
getdep(v,u);
}
}
void calc(int u,int l,int f){
dep[u]=l,mx=tot=0;
getdep(u,0);
lim=1,tim=0;
while(lim<=2*mx)lim<<=1,tim++;
for(int i=0;i<lim;i++)A[i]=0,rev[i]=(rev[i>>1]>>1)|((i&1)<<(tim-1));
for(int i=1;i<=tot;i++)A[val[i]]++;
ntt(A,1);
for(int i=0;i<lim;i++)A[i]=mul(A[i],A[i]);
ntt(A,-1);
for(int i=0;i<=2*mx;i++)ans[i]+=f*A[i];
}
void solve(int u){
vis[u]=1;
calc(u,0,1);
for(int e=adj[u];e;e=nxt[e]){
int v=to[e];
if(vis[v])continue;
calc(v,1,-1),maxn=siz[v];
getrt(v,rt=0);
solve(rt);
}
}
int n;
int main(){
maxn=son[0]=n=read();
for(int i=1;i<n;i++){
int u=read()+1,v=read()+1;
addedge(u,v),addedge(v,u);
}
getrt(1,0);
solve(rt);
long double res=0;
for(int i=0;i<n;i++)res+=(long double)ans[i]/(i+1);
printf("%.4Lf",res);
}