uoj#351. 新年的叶子(概率期望)

传送门

数学还是太差了,想了半天都没想出来

首先有一个定理,如果直径(这里考虑经过的点数)为奇数,所有直径有同一个中点,如果直径为偶数,所有直径有同一条最中间的边。这个可以用反证法,假设不成立的话直径会变长

如果直径为奇数,那么我们可以以共同经过的那个点为根,把所有在直径上的叶子按不同的子树分类,如果某两个叶子在同一棵子树,那么它们不可能构成直径,如果在不同的子树,那么必定能构成直径。所以把所有在直径上的叶子按不同的子树分为若干个集合

如果是偶数,那么就直接分为两个集合

我们现在要求的,就是这些集合中只剩一个集合没有被完全染黑的期望时间

可以考虑容斥,枚举一个集合\(i\),让它成为没有被完全染黑的那个集合,那么我们现在只关心其它所有集合被全部染黑的时间,设\(m\)为叶子总数,\(s\)为剩下的集合中点的总数,设\(f_i\)为还剩下\(i\)个点没有被染色时染一个点的期望时间,那么有\(f_i=1+\frac{m-i}{m}f_i\),所以\(f_i=\frac{m}{i}\),那么剩下的集合全部被染色的时间就是\(\sum_{i=1}^{s}\frac{m}{i}\),预处理一下就可以了

然而按我们上面的枚举方法,有可能会有所有集合都被染黑的情况。考虑每一种所有集合都被染黑的方案,如果最后一个被染黑的集合黑了,那么其他集合肯定也黑了。所以每一个方案中每一个最后被染黑的集合会被其它所有集合枚举到\(t-1\)次(\(t\)为集合的个数),也就是说每一种全被染黑的方案会被统计\(t-1\)次,减掉就好了

//minamoto
#include<bits/stdc++.h>
#define R register
#define fp(i,a,b) for(R int i=a,I=b+1;i<I;++i)
#define fd(i,a,b) for(R int i=a,I=b-1;i>I;--i)
#define go(u) for(int i=head[u],v=e[i].v;i;i=e[i].nx,v=e[i].v)
using namespace std;
char buf[1<<21],*p1=buf,*p2=buf;
inline char getc(){return p1==p2&&(p2=(p1=buf)+fread(buf,1,1<<21,stdin),p1==p2)?EOF:*p1++;}
int read(){
    R int res,f=1;R char ch;
    while((ch=getc())>'9'||ch<'0')(ch=='-')&&(f=-1);
    for(res=ch-'0';(ch=getc())>='0'&&ch<='9';res=res*10+ch-'0');
    return res*f;
}
char sr[1<<21],z[20];int C=-1,Z=0;
inline void Ot(){fwrite(sr,1,C+1,stdout),C=-1;}
void print(R int x){
    if(C>1<<20)Ot();if(x<0)sr[++C]='-',x=-x;
    while(z[++Z]=x%10+48,x/=10);
    while(sr[++C]=z[Z],--Z);sr[++C]='\n';
}
const int N=5e5+5,P=998244353;
inline int add(R int x,R int y){return x+y>=P?x+y-P:x+y;}
inline int dec(R int x,R int y){return x-y<0?x-y+P:x-y;}
inline int mul(R int x,R int y){return 1ll*x*y-1ll*x*y/P*P;}
int ksm(R int x,R int y){
	R int res=1;
	for(;y;y>>=1,x=mul(x,x))if(y&1)res=mul(res,x);
	return res;
}
struct eg{int v,nx;}e[N<<1];int head[N],tot;
inline void add_edge(R int u,R int v){e[++tot]={v,head[u]},head[u]=tot;}
int inv[N],sum[N],dep[N],fa[N],deg[N],st[N];
int n,tmp,res,u,v,len,m,top,s,ans;
void dfs(int u,int fat,int &x){
	fa[u]=fat,dep[u]=dep[fat]+1;
	if(dep[u]==len/2)++x;
	go(u)if(v!=fat)dfs(v,u,x);
}
int main(){
//	freopen("testdata.in","r",stdin);
	n=read();
	fp(i,1,n-1){
		u=read(),v=read(),add_edge(u,v),add_edge(v,u);
		++deg[u],++deg[v];
	}
	int rt=1,tl=1;
	dfs(1,0,tmp);
	fp(i,1,n){
		if(dep[i]>dep[rt])rt=i;
		if(deg[i]==1)++m;
	}
	dfs(rt,0,tmp);
	fp(i,1,n)if(dep[i]>dep[tl])tl=i;
	len=dep[tl];
	inv[1]=1,sum[1]=m;
	fp(i,2,n){
		inv[i]=1ll*inv[P%i]*(P-P/i)%P;
		sum[i]=add(sum[i-1],mul(m,inv[i]));
	}
	if(len&1){
		int x=0;
		for(R int i=tl;i;i=fa[i])if(dep[i]==((len+1)>>1))x=i;
		dep[x]=0;
		go(x){
			dfs(v,x,tmp=0);
			if(tmp)st[++top]=tmp,s+=tmp;
		}
	}else{
		int x1=0,x2=0;
		for(R int i=tl;i;i=fa[i]){
			if(dep[i]==(len>>1))x1=i;
			if(dep[i]==(len>>1)+1)x2=i;
		}
		dep[x2]=0,dfs(x1,x2,st[++top]);
		dep[x1]=0,dfs(x2,x1,st[++top]);
		s=st[1]+st[2];
	}
	fp(i,1,top)ans=add(ans,sum[s-st[i]]);
	printf("%d\n",dec(ans,mul(top-1,sum[s])));
	return 0;
}
posted @ 2019-01-10 10:11  bztMinamoto  阅读(261)  评论(0编辑  收藏  举报
Live2D