【XSY3338】game(期望,点分治,FFT)

题面

game

题解

首先可以看出 “等概率选连通块->连通块内等概率选点” 相当于 “全局等概率选点”。

一开始感觉无从下手,但是题目中还是给了一点提示。

题目让我们输出答案乘 \(n!\) 后的结果,于是想到枚举一个 \(1\sim n\) 的排列 \(p_i\) 表示依次选择并删除的点的序列。那么对于某一个特定的 \(p_i\),这种删点方法中所有点被捶的总次数等于 \(\sum\limits_{i=1}^n (p_i所在连通块还剩下的点数)\)

转换一下角度,考虑每个点会被捶多少次。于是所有点被捶的总次数又可以表示为:\(\sum\limits_{i=1}^n\sum\limits_{j=1}^n[删除p_j前p_j与p_i仍连通]=\sum\limits_{i=1}^n\sum\limits_{j=1}^n[删除j前j与i仍连通]\)

其中 “删除j前j与i仍连通” 可以巧妙地转化为 “\(i\)\(j\) 路径上所有点在 \(p\) 序列中出现的位置(相当于被删除的时间)都比 \(j\) 后”,即如果设 \(t_{p_i}=i\),就有 \(t_j=\min\limits_{v\in path(i,j)}t_v\),其中 \(path(i,j)\) 表示 \(i\)\(j\) 的路径。

考虑某个点 \(i\) 在所有 \(p\) 的排列中被捶的总次数,这之中 \(t_j=\min\limits_{v\in path(i,j)}t_v\) 的概率是 \(\dfrac{1}{dist(i,j)}\),其中 \(dist(i,j)\) 表示 \(path(i,j)\) 集合的大小,即 \(i\)\(j\) 路径上的总点数。

于是我们要求的就是 \(\sum\limits_{i=1}^n\sum\limits_{j=1}^n\dfrac{1}{dist(i,j)}\)

对于每一个 \(dis\in [1,n]\) 求出 \(dist(i,j)=dis\) 的点对 \((i,j)\) 数,使用点分治+FFT即可。

#include<bits/stdc++.h>

#define LN 19
#define N 100010
#define INF 0x7fffffff

using namespace std;

namespace modular
{
	const int mod=1000000007;
	inline int add(int x,int y){return x+y>=mod?x+y-mod:x+y;}
	inline int dec(int x,int y){return x-y<0?x-y+mod:x-y;}
	inline int mul(int x,int y){return 1ll*x*y%mod;}
}using namespace modular;

inline int poww(int a,int b)
{
	int ans=1;
	while(b)
	{
		if(b&1) ans=mul(ans,a);
		a=mul(a,a);
		b>>=1;
	}
	return ans;
}

inline int read()
{
	int x=0,f=1;
	char ch=getchar();
	while(ch<'0'||ch>'9')
	{
		if(ch=='-') f=-1;
		ch=getchar();
	}
	while(ch>='0'&&ch<='9')
	{
		x=(x<<1)+(x<<3)+(ch^'0');
		ch=getchar();
	}
	return x*f;
}

const double pi=acos(-1);

typedef vector<int> poly;

struct Complex
{
	double x,y;
	Complex(){};
	Complex(double a,double b){x=a,y=b;}
}F[N<<2],w[LN][N<<2][2];

Complex operator + (Complex a,Complex b){return Complex(a.x+b.x,a.y+b.y);}
Complex operator - (Complex a,Complex b){return Complex(a.x-b.x,a.y-b.y);}
Complex operator * (Complex a,Complex b){return Complex(a.x*b.x-a.y*b.y,a.x*b.y+a.y*b.x);}

int n;
int cnt,head[N],nxt[N<<1],to[N<<1];
int nn,rt,maxn,size[N],fa[N];
int sum[N<<1];
bool vis[N];

void init(int limit)
{
	for(int bit=0,mid=1;mid<limit;bit++,mid<<=1)
	{
		Complex gn(cos(pi/mid),sin(pi/mid));
		Complex ign(cos(pi/mid),-sin(pi/mid));
		Complex g(1,0),ig(1,0);
		for(int j=0;j<mid;j++,g=g*gn,ig=ig*ign)
			w[bit][j][0]=g,w[bit][j][1]=ig;
	}
}

void adde(int u,int v)
{
	to[++cnt]=v;
	nxt[cnt]=head[u];
	head[u]=cnt;
}

void getsize(int u,int fa)
{
	size[u]=1;
	for(int i=head[u];i;i=nxt[i])
	{
		int v=to[i];
		if(vis[v]||v==fa) continue;
		getsize(v,u);
		size[u]+=size[v];
	}
}

void getroot(int u,int fa)
{
	int nmax=0;
	for(int i=head[u];i;i=nxt[i])
	{
		int v=to[i];
		if(vis[v]||v==fa) continue;
		getroot(v,u);
		nmax=max(nmax,size[v]);
	}
	nmax=max(nmax,nn-size[u]);
	if(nmax<maxn) rt=u,maxn=nmax;
}

int maxdis;

void getdis(int u,int fa,int dis)
{
	F[dis].x++;
	maxdis=max(maxdis,dis);
	for(int i=head[u];i;i=nxt[i])
	{
		int v=to[i];
		if(vis[v]||v==fa) continue;
		getdis(v,u,dis+1);
	}
}

int rev[N<<2];

void FFT(Complex *a,int limit,int opt)
{
	opt=(opt<0);
	for(int i=0;i<limit;i++)
		rev[i]=(rev[i>>1]>>1)|((i&1)*(limit>>1));
	for(int i=0;i<limit;i++)
		if(i<rev[i]) swap(a[i],a[rev[i]]);
	for(int bit=0,mid=1;mid<limit;bit++,mid<<=1)
	{
		for(int i=0,len=mid<<1;i<limit;i+=len)
		{
			for(int j=0;j<mid;j++)
			{
				Complex x=a[i+j],y=w[bit][j][opt]*a[i+mid+j];
				a[i+j]=x+y,a[i+mid+j]=x-y;
			}
		}
	}
	if(opt)
		for(int i=0;i<limit;i++)
			a[i].x/=limit;
}

void calc(int u,int dis,int tag)
{
	maxdis=0,getdis(u,0,dis);
	int limit=1;
	while(limit<=(maxdis<<1)) limit<<=1;
	FFT(F,limit,1);
	for(int i=0;i<limit;i++)
		F[i]=F[i]*F[i];
	FFT(F,limit,-1);
	for(int i=0;i<limit;i++) sum[i+1]+=tag*(int)(F[i].x+0.5);
	for(int i=0;i<limit;i++) F[i]=Complex(0,0);
}

void solve(int u)
{
	vis[u]=1;
	calc(u,0,1);
	for(int i=head[u];i;i=nxt[i])
	{
		int v=to[i];
		if(vis[v]) continue;
		calc(v,1,-1);
		getsize(v,0);
		nn=size[v],maxn=INF,getroot(v,0);
		fa[rt]=u;
		solve(rt);
	}
}

int main()
{
	n=read();
	int limit=1;
	while(limit<=(n<<1)) limit<<=1;
	init(limit);
	for(int i=1;i<n;i++)
	{
		int u=read(),v=read();
		adde(u,v),adde(v,u);
	}
	getsize(1,0);
	nn=size[1],maxn=INF,getroot(1,0);
	solve(rt);
	int fac=1;
	for(int i=1;i<=n;i++) fac=mul(fac,i);
	int ans=0;
	for(int i=1;i<=n;i++)
		ans=add(ans,mul(sum[i],poww(i,mod-2)));
	printf("%d\n",mul(ans,fac));
	return 0;
}
/*
3
1 2
2 3
*/
posted @ 2022-10-30 11:01  ez_lcw  阅读(27)  评论(0编辑  收藏  举报