【BZOJ3451】【Tyvj1953】—Normal(点分治+NTT)

传送门

Description

某天WJMZBMR学习了一个神奇的算法:树的点分治!
这个算法的核心是这样的:
消耗时间=0
Solve(树 a)
消耗时间 += a 的 大小
如果 a 中 只有 1 个点
退出
否则在a中选一个点x,在a中删除点x
那么a变成了几个小一点的树,对每个小树递归调用Solve
我们注意到的这个算法的时间复杂度跟选择的点x是密切相关的。
如果x是树的重心,那么时间复杂度就是O(nlogn)
但是由于WJMZBMR比较傻逼,他决定随机在a中选择一个点作为x!
Sevenkplus告诉他这样做的最坏复杂度是O(n^2)
但是WJMZBMR就是不信>_<。。。
于是Sevenkplus花了几分钟写了一个程序证明了这一点。。。你也试试看吧_
现在给你一颗树,你能告诉WJMZBMR他的傻逼算法需要的期望消耗时间吗?(消耗时间按在Solve里面的那个为标准)

Input

第一行一个整数n,表示树的大小
接下来n-1行每行两个数a,b,表示a和b之间有一条边
注意点是从0开始标号的

Output
一行一个浮点数表示答案
四舍五入到小数点后4位
如果害怕精度跪建议用long double或者extended

Sample Input

3

0 1

1 2

Sample Output

5.6667

HINT

n<=30000

考虑计算一个点iijj的贡献
那显然只有jj为分治中心的时候ii对它才有贡献
考虑如果分治中心不在路径ii~jj内那显然不会产生任何影响
而如果一次中心把iijj隔开了那就没有贡献了
那也就是说iijj的时候是分治中心是jj,而不是ii~jj路径上的其他点
也就是说概率是1dis(i,j)\frac 1 {dis(i,j)}
那答案就是i=1nj=1n1dis(i,j)\sum_{i=1}^{n}\sum_{j=1}^{n}\frac{1}{dis(i,j)}

而也就是可以求出每个距离出现了多少次,除以距离就可以了
而这个是可以通过点分治+NTTNTTO(nlog2n)O(nlog^2n)的时间求出来

复杂度O(nlog2n)O(nlog^2n)

#include<bits/stdc++.h>
using namespace std;
const int RLEN=1<<20|1;
#define ll long long
inline char gc(){
	static char ibuf[RLEN],*ib,*ob;
	(ib==ob)&&(ob=(ib=ibuf)+fread(ibuf,1,RLEN,stdin));
	return (ib==ob)?EOF:*ib++;
}
#define gc getchar
inline int read(){
	char ch=gc();
	int res=0,f=1;
	while(!isdigit(ch))f^=ch=='-',ch=gc();
	while(isdigit(ch))res=(res+(res<<2)<<1)+(ch^48),ch=gc();
	return f?res:-res;
}
const int mod=998244353,g=3;
const int M=120005,N=30005;
inline int add(int a,int b){
	return a+b>=mod?a+b-mod:a+b;
}
inline int dec(int a,int b){
	return a>=b?a-b:a-b+mod;
}
inline int mul(int a,int b){
	return 1ll*a*b>=mod?1ll*a*b%mod:a*b;
}
inline int ksm(int a,int b,int res=1){
	for(;b;b>>=1,a=mul(a,a))(b&1)?(res=mul(res,a)):0;return res;
}
int rev[M],A[M],lim,tim;
inline void ntt(int *f,int kd){
	for(int i=0;i<lim;i++)if(i<rev[i])swap(f[i],f[rev[i]]);
	int bas=(kd==1)?g:((mod+1)/3);
	for(int mid=1;mid<lim;mid<<=1){
		int now=ksm(bas,(mod-1)/(mid<<1));
		for(int i=0;i<lim;i+=(mid<<1)){
			int w=1;
			for(int j=0;j<mid;j++,w=mul(w,now)){
				int a0=f[i+j],a1=mul(w,f[i+j+mid]);
				f[i+j]=add(a0,a1),f[i+j+mid]=dec(a0,a1);
			}
		}
	}
	if(kd==-1)for(int i=0,inv=ksm(lim,mod-2);i<lim;i++)f[i]=mul(f[i],inv);
}
int dep[N],val[N],ans[N],maxn,rt,mx,siz[N],son[N],adj[N],nxt[N<<1],to[N<<1],vis[N],cnt,tot;
inline void addedge(int u,int v){
	nxt[++cnt]=adj[u],adj[u]=cnt,to[cnt]=v;
}
void getrt(int u,int fa){
	siz[u]=1,son[u]=0;
	for(int e=adj[u];e;e=nxt[e]){
		int v=to[e];
		if(vis[v]||v==fa)continue;
		getrt(v,u),siz[u]+=siz[v];
		if(siz[v]>son[u])son[u]=siz[v];
	}
	son[u]=max(son[u],maxn-siz[u]);
	if(son[u]<son[rt])rt=u;
}
void getdep(int u,int fa){
	val[++tot]=dep[u],mx=max(mx,dep[u]);
	for(int e=adj[u];e;e=nxt[e]){
		int v=to[e];
		if(vis[v]||v==fa)continue;
		dep[v]=dep[u]+1;
		getdep(v,u);
	}
}
void calc(int u,int l,int f){
	dep[u]=l,mx=tot=0;
	getdep(u,0);
	lim=1,tim=0;
	while(lim<=2*mx)lim<<=1,tim++;
	for(int i=0;i<lim;i++)A[i]=0,rev[i]=(rev[i>>1]>>1)|((i&1)<<(tim-1));
	for(int i=1;i<=tot;i++)A[val[i]]++;
	ntt(A,1);
	for(int i=0;i<lim;i++)A[i]=mul(A[i],A[i]);
	ntt(A,-1);
	for(int i=0;i<=2*mx;i++)ans[i]+=f*A[i];
}
void solve(int u){
	vis[u]=1;
	calc(u,0,1);
	for(int e=adj[u];e;e=nxt[e]){
		int v=to[e];
		if(vis[v])continue;
		calc(v,1,-1),maxn=siz[v];
		getrt(v,rt=0);
		solve(rt);
	}
}
int n;
int main(){
	maxn=son[0]=n=read();
	for(int i=1;i<n;i++){
		int u=read()+1,v=read()+1;
		addedge(u,v),addedge(v,u);
	}
	getrt(1,0);
	solve(rt);
	long double res=0;
	for(int i=0;i<n;i++)res+=(long double)ans[i]/(i+1);
	printf("%.4Lf",res);
} 
posted @ 2019-06-01 18:04  Stargazer_cykoi  阅读(175)  评论(0编辑  收藏  举报