[EOJ629] 两开花

Description

给定一棵以 \(1\) 为根 \(n\) 个节点的树。

定义 \(f(k)\) :从树上等概率随机选出 \(k\) 个节点,这 \(k\) 个点的虚树大小的期望。

一个点 \(x\) 在这些被选出的 \(k\) 个点的虚树上,当且仅当它满足下列条件至少一个:

  • \(x\) 被选出。
  • 存在两个被选出的节点 \(a,b\),使得 \(\operatorname{lca}(a,b)=x\)

给定 \(m\),求 \(f(1),f(2),\cdots,f(m)\)。 对 \(998244353\) 取模。\(n\leq 4\cdot 10^5\)

Sol

又是套着期望皮的计数题。

对于每个点 \(i\) 求出有多少种方案对答案有贡献即可:

  • \(i\) 被选出,总方案数为 \(C(n-1,k-1)\)
  • \(i\) 至少两个儿子的子树中存在被选出的点。

第二种不太好算,考虑用总方案数减去不合法的方案数。

总方案数就是 \(C(n-1,k)\)

如果点 \(i\) 的子树中没有被选中的,方案数为 \(C(n-sze[i],k)\)

只有一个儿子的子树中有被选中的,可以枚举儿子 \(j\),方案数就是 \(\sum\limits_{j} C(n-sze[i]+sze[j],k)\)

注意到这样的话,\(i\) 子树中没有被选中的方案数被多算了 儿子个数次,所以还需要加上 \(son[i]\times C(n-sze[i],k)\)

所以

\[f(k)=\sum\limits_{i=1}^n C_{n-1}^{k-1}+C_{n-1}^k+(son[i]-1)\times C_{n-sze[i]}^k-\sum_j C_{n-sze[i]+sze[j]}^k \]

\[f(k)=\sum\limits_{i=1}^n C_{n}^{k}+(son[i]-1)\times C_{n-sze[i]}^k-\sum_j C_{n-sze[i]+sze[j]}^k \]

如何对于每个 \(k\) 快速求呢?

观察到式子中的每一项组合数的上标都是 \(k\),所以我们可以开个桶 \(buc[i]\),在形如 \(buc[n-sze[i]]\) 的地方加上 \(son[i]+1\),在 \(buc[n-sze[i]+sze[j]]\)\(-1\)

好处就是,再推一步式子:

\[f(k)=\sum_{i=0}^n buc[i]\cdot C_i^k \]

这就是个卷积的形式,\(\mathbf{NTT}\)优化就吼了。

Code

#pragma GCC optimize(2)
#include<bits/stdc++.h>
using std::min;
using std::max;
using std::swap;
using std::vector;
typedef double db;
typedef long long ll;
#define pb(A) push_back(A)
#define pii std::pair<int,int>
#define all(A) A.begin(),A.end()
#define mp(A,B) std::make_pair(A,B)
const int N=2e6+5;
const int mod=998244353;

int son[N],sze[N],buc[N];
int n,m,cnt,head[N],fac[N];
int a[N],b[N],lim,rev[N],ifac[N];

struct Edge{
    int to,nxt;
}edge[N<<1];

void add(int x,int y){
    edge[++cnt].to=y;
    edge[cnt].nxt=head[x];
    head[x]=cnt;
}

int ksm(int a,int b=mod-2,int ans=1){
    while(b){
        if(b&1) ans=1ll*ans*a%mod;
        a=1ll*a*a%mod;b>>=1;
    } return ans;
}

void ntt(int *f,int g){
    for(int i=1;i<lim;i++) if(i<rev[i]) swap(f[i],f[rev[i]]);
    for(int mid=1;mid<lim;mid<<=1){
        int tmp=ksm(g,(mod-1)/(mid<<1));
        for(int R=mid<<1,j=0;j<lim;j+=R){
            for(int w=1,k=0;k<mid;k++,w=1ll*w*tmp%mod){
                int x=f[j+k],y=1ll*w*f[j+k+mid]%mod;
                f[j+k]=(x+y)%mod,f[j+k+mid]=(mod+x-y)%mod;
            }
        }
    } if(g>3)
        for(int in=ksm(lim),i=0;i<lim;i++) f[i]=1ll*f[i]*in%mod;
}

int getint(){
    int X=0,w=0;char ch=getchar();
    while(!isdigit(ch))w|=ch=='-',ch=getchar();
    while( isdigit(ch))X=X*10+ch-48,ch=getchar();
    if(w) return -X;return X;
}

void init(int n){
    fac[0]=ifac[0]=1;
    for(int i=1;i<=n;i++) fac[i]=1ll*fac[i-1]*i%mod;
    ifac[n]=ksm(fac[n]);
    for(int i=n-1;i;i--) ifac[i]=1ll*ifac[i+1]*(i+1)%mod;
}

void dfs(int now,int fa=0){
    sze[now]=1; int tot=0; buc[n]++;
    for(int i=head[now];i;i=edge[i].nxt){
        int to=edge[i].to;
        if(sze[to]) continue;
        tot++; dfs(to,now);
        sze[now]+=sze[to];
    }
    for(int i=head[now];i;i=edge[i].nxt){
        int to=edge[i].to;
        if(to==fa) continue;
        (buc[n-sze[now]+sze[to]]+=mod-1)%=mod;
    } (buc[n-sze[now]]+=tot-1+mod)%=mod;
}

int C(int n,int m){
    if(n<m) return 0;
    return 1ll*ifac[n]*fac[m]%mod*fac[n-m]%mod;
}

signed main(){
    n=getint(),m=getint(),init(N-5);
    for(int i=1;i<n;i++){
        int x=getint(),y=getint();
        add(x,y),add(y,x);
    } dfs(1);
    lim=1;while(lim<=n+n) lim<<=1;
    for(int i=1;i<lim;i++) rev[i]=(rev[i>>1]>>1)|(i&1?lim>>1:0);
    for(int i=0;i<=n;i++)
        a[n-i]=1ll*buc[i]*fac[i]%mod,
        b[i]=ifac[i];
    ntt(a,3),ntt(b,3);
    for(int i=0;i<lim;i++) a[i]=1ll*a[i]*b[i]%mod;
    ntt(a,(mod+1)/3);
    for(int i=1;i<=m;i++) 
        printf("%lld\n",1ll*a[n-i]*ifac[i]%mod*C(n,i)%mod);
    return 0;
}

posted @ 2019-02-11 20:30  YoungNeal  阅读(206)  评论(0编辑  收藏  举报