[BZOJ3451]Normal(点分治+FFT)

[BZOJ3451]Normal(点分治+FFT)

题面

给你一棵 n个点的树,对这棵树进行随机点分治,每次随机一个点作为分治中心。定义消耗时间为每层分治的子树大小之和,求消耗时间的期望。

分析

根据期望的线性性,答案是\(\sum_{i=1}^n(i的期望子树大小)=\sum_{i=1}^n \sum_{j=1}^n [j在i的点分治子树内]\)

考虑j在i的点分治子树内的条件,显然i到j的路径上的所有点中,i是第一个被选择为分治中心的。否则如果选的点不是i,那么i和j会被分到两棵子树中。第一个被选择的的概率是\(\frac{1}{dist(i,j)+1}\)(\(dist(i,j)\)表示i到j的距离)。那么上式就可以写成\(\sum_{i=1}^n \sum_{j=1}^n \frac{1}{dist(i,j)+1}\)

转换一下,设\(cnt[d]\)表示\(dist(i,j)=d\)\((i,j)\)个数,那么答案为\(\sum_{d=0}^{n-1} \frac{cnt[d]}{d+1}\)。考虑如何求\(cnt[k]\)

我们在点分治的过程中,dfs出深度为i的节点个数cd[i]。那么求经过根节点的答案的时候就是\(cnt[i]=\sum_{j=0}^i cd[j]cd[i-j]\).容易看出这是一个卷积的形式,直接用cd和自身FFT求卷积即可。

注意最后要像一般的点分治一样容斥一下.

时间复杂度满足递推式\(T(n)=2T(\frac{n}{2})+\frac{1}{2}n\log n\).根据主定理的第二种情况,答案是\(\Theta (n\log^2 n)\)

代码

#include<iostream>
#include<cstdio>
#include<cstring>
#include<cmath>
#define maxn 200000
using namespace std;
typedef long double db;
typedef long long ll;
const db pi=acos(-1.0);
struct com{//复数类
    double real;
    double imag;
    com(){

    } 
    com(double _real,double _imag){
        real=_real;
        imag=_imag;
    }
    com(double x){
        real=x;
        imag=0;
    }
    void operator = (const com x){
        this->real=x.real;
        this->imag=x.imag;
    }
    void operator = (const double x){
        this->real=x;
        this->imag=0;
    }
    friend com operator + (com p,com q){
        return com(p.real+q.real,p.imag+q.imag);
    }
    friend com operator + (com p,double q){
        return com(p.real+q,p.imag);
    }
    void operator += (com q){
        *this=*this+q;
    }
    void operator += (double q){
        *this=*this+q;
    }
    friend com operator - (com p,com q){
        return com(p.real-q.real,p.imag-q.imag);
    }
    friend com operator - (com p,double q){
        return com(p.real-q,p.imag);
    }
    void operator -= (com q){
        *this=*this-q;
    }
    void operator -= (double q){
        *this=*this-q;
    }
    friend com operator * (com p,com q){
        return com(p.real*q.real-p.imag*q.imag,p.real*q.imag+p.imag*q.real);
    }
    friend com operator * (com p,double q){
        return com(p.real*q,p.imag*q);
    } 
    void operator *= (com q){
        *this=(*this)*q;
    }
    void operator *= (double q){
        *this=(*this)*q;
    }
    friend com operator / (com p,double q){
        return com(p.real/q,p.imag/q);
    } 
    void operator /= (double q){
        *this=(*this)/q;
    } 
    void print(){
        printf("%lf + %lf i ",real,imag);
    }
};
void fft(com *x,int n,int type){
	static int rev[maxn+5];
	int dn=1,k=0;
	while(dn<n){
		dn*=2;
		k++;
	}
	for(int i=0;i<n;i++) rev[i]=(rev[i>>1]>>1)|((i&1)<<(k-1));
	for(int i=0;i<n;i++) if(i<rev[i]) swap(x[i],x[rev[i]]);
	for(int len=1;len<n;len*=2){
		int sz=len*2;
		com wn1=com(cos(2*pi/sz),sin(2*pi/sz)*type);
		for(int l=0;l<n;l+=sz){
			int r=l+len-1;
			com wnk=1;
			for(int i=l;i<=r;i++){
				com tmp=x[i+len];
				x[i+len]=x[i]-wnk*tmp;
				x[i]=x[i]+wnk*tmp;
				wnk*=wn1;
			}
		}
	}
	if(type==-1) for(int i=0;i<n;i++) x[i]/=n;
}
void mul(com *a,com *b,com *ans,int n){//封装多项式乘法
	fft(a,n,1);
	if(a!=b) fft(b,n,1);
	for(int i=0;i<n;i++) ans[i]=a[i]*b[i];
	fft(ans,n,-1);
}


struct edge{
	int from;
	int to;
	int next;
}E[maxn*2+5];
int head[maxn+5];
int esz=1;
void add_edge(int u,int v){
	esz++;
	E[esz].from=u;
	E[esz].to=v;
	E[esz].next=head[u];
	head[u]=esz;
}

bool vis[maxn+5];
int sz[maxn+5],f[maxn+5];
int root;
int tot_sz;
void get_root(int x,int fa){
	sz[x]=1;
	f[x]=0;
	for(int i=head[x];i;i=E[i].next){
		int y=E[i].to;
		if(y!=fa&&!vis[y]){
			get_root(y,x);
			sz[x]+=sz[y];
			f[x]=max(f[x],sz[y]);
		}
	}
	f[x]=max(f[x],tot_sz-sz[x]);
	if(f[x]<f[root]) root=x;
}

int maxd;
com ff[maxn+5];//当前子树中深度为x的节点个数
com res[maxn+5];
ll cnt[maxn+5];

void get_deep(int x,int fa,int d){
	ff[d]+=1;
	maxd=max(maxd,d);
	for(int i=head[x];i;i=E[i].next){
		int y=E[i].to;
		if(y!=fa&&!vis[y]){
			get_deep(y,x,d+1);
		}
	}
}

void calc(int x,int d,int type){
	maxd=0;
	get_deep(x,0,d);
	int dn=1,k=0;
	while(dn<=maxd*2){
		dn*=2;
		k++;
	}
	mul(ff,ff,res,dn);//卷积
	for(int i=0;i<=maxd*2;i++) cnt[i]+=(ll)(res[i].real+0.5)*type;//用卷积结果更新cnt
	for(int i=0;i<=dn;i++) ff[i]=0;
}

void solve(int x){
	vis[x]=1;
	calc(x,0,1);
	for(int i=head[x];i;i=E[i].next){
		int y=E[i].to;
		if(!vis[y]){
			calc(y,1,-1);//容斥,减去一条边经过两次的答案
			root=0;
			tot_sz=sz[y];
			get_root(y,0);
			solve(root);
		}
	}
}

int n;
int main(){
	int u,v;
	scanf("%d",&n);
	for(int i=1;i<n;i++){
		scanf("%d %d",&u,&v);
		u++;
		v++;
		add_edge(u,v);
		add_edge(v,u);
	}
	f[0]=n+1;
	root=0;
	tot_sz=n;
	get_root(1,0);
	solve(root);
	db ans=0;
	for(int i=0;i<=n-1;i++){
		ans+=(db)cnt[i]*1/(i+1);
	}
	printf("%.4Lf\n",ans);
}
posted @ 2019-10-13 11:51  birchtree  阅读(288)  评论(0编辑  收藏  举报