第二类斯特林数学习笔记

第二类$ Stirling$数是把包含n个元素的集合划分为正好k个非空子集的方法的数目。   
递推公式为$ S(n,k) = S(n-1,k-1) + kS(n-1,k).$


这类斯特林数有一个很好的性质:

$ x^k=\sum\limits_{j=0}^kC_x^jS(k,j)j!$

其意义是$ k$个球放入$ x$个有标号盒子的方案数,枚举空盒的数量,乘上阶乘以及选出这些空盒的方案即可


$ stirling$数可以通过组合意义展开:

$ S(n,m)= \frac{1}{m!}*\sum\limits_{k=0}^m(-1)^k(m-k)^nC_m^k$

我们枚举空盒$ k$的个数,就可以容斥出恰好划分成$ m$个非空子集的方案数

除$ m!$是因为划分的集合是无标号集合而容斥的集合是带标号集合


这个式子可以化为卷积形式:

$ S(n,m)= \frac{1}{m!}*\sum\limits_{k=0}^m(-1)^k(m-k)^nC_m^k$

$ S(n,m)= \frac{1}{m!}*\sum\limits_{k=0}^m(-1)^k(m-k)^n\frac{m!}{k!(m-k)!}$

$ S(n,m)= \sum\limits_{k=0}^m(-1)^k(m-k)^n\frac{1}{k!(m-k)!}$

$ S(n,m)= \sum\limits_{k=0}^m\frac{(-1)^k}{k!}*\frac{(m-k)^n}{(m-k)!}$

令$ A(x)=\frac{(-1)^k}{k!},B(x)=\frac{m^n}{m!}$

则有$ S(n,m)=\sum\limits_{k=0}^mA(k)B(m-k)$

是一个经典的卷积模型,可以$ FFT/NTT$在$ O(n log n)$的时间复杂度计算$ S(n,0..n)$


例题:

$ 1$.给定一棵$ n$个节点的树,每个点有$ yd[i]$的概率原地不动,否则以均等的概率往周围的点移动。问每个点到根的移动次数的$ k$次方的期望

来源:联考模拟赛

数据范围:$ n,k<=100000,nk<=1000000$


$ solution$

考虑$ DP$,用$ val[i]^j$表示点$ i$到根的路径上的$ j$次方值的期望

每次枚举每个点回到自己的概率以及除了回到自己的情况以外的常数贡献(为防止父亲对自己产生影响先忽略父亲影响部分的值)

最后从上往下加上父亲对自己的贡献

发现复杂度瓶颈是从$ val[i]^j$推出$ (val[i]+1)^j$

这部分只能二项式展开导致我们不得不枚举$ k$然后每次暴力二项式展开$ O(k)$转移

总复杂度$ O(nk^2)$


考虑转化成第二类斯特林数

我们把$ val[i]^j$二项式展开得到$ val[i]^j=\sum\limits_{k=0}^jC_{val[i]}^kS(j,k)k!$

我们只需要维护每个点到根的$ C_{val[x]}^j $即可

由于组合数有$ C_{val[x]+1}^j=C_{val[x]}^{j}+C_{val[x]}^{j-1}$

因此这可以做到$ O(1)$转移,时间复杂度$ O(nk+k^2)$其中$ k^2$是求斯特林数的复杂度

由于我们只需要求$ S(k,0...k)$,我们可以直接用上面的式子$ NTT$在$ O(n log n)$的时间内求出

总复杂度$ O(nk+k log k)$,可以通过此题


 

$ my \ code$

#include<ctime>
#include<cmath>
#include<cstdio>
#include<cstring>
#include<iostream>
#include<algorithm>
#include<queue>
#define p 998244353
#define M 200010
#define rt register int
#define ll long long
using namespace std;
inline ll read(){
    ll x = 0; char zf = 1; char ch = getchar();
    while (ch != '-' && !isdigit(ch)) ch = getchar();
    if (ch == '-') zf = -1, ch = getchar();
    while (isdigit(ch)) x = x * 10 + ch - '0', ch = getchar(); return x * zf;
}
void write(ll y){if(y<0)putchar('-'),y=-y;if(y>9)write(y/10);putchar(y%10+48);}
void writeln(const ll y){write(y);putchar('\n');}
int i,j,k,m,n,x,y,z,cnt;
int F[M],L[M],N[M],a[M],fa[M],d[M];
ll yd[M],inv[1000010];int val[100010][2];
void add(int x,int y){
    a[++k]=y;
    if(!F[x])F[x]=k;
    else N[L[x]]=k;
    L[x]=k;
}
void dfs(int x,int pre){
    fa[x]=pre;
    for(rt i=F[x];i;i=N[i])if(a[i]!=pre)dfs(a[i],x);
}
ll ksm(ll x,ll y){
    if(!y)return 1;ll ew=1;
    while(y>1){
        if(y&1)y--,ew=x*ew%p;
        y>>=1,x=x*x%p;
    }return x*ew%p;
}
inline int calc(int x,int y){
    return val[x][y&1^1];
}
ll stop[M];
ll without(int x,int y){
    ll ans=(yd[x]*calc(x,y)+(1ll-yd[x])*inv[d[x]]%p*calc(fa[x],y))%p;
    for(rt i=F[x];i;i=N[i])if(a[i]!=fa[x])(ans+=(1ll-yd[x])*inv[d[x]]%p*(without(a[i],y)+calc(a[i],y)))%=p;
    return val[x][y&1]=ans*stop[x]%p;
}
void solve(int x,int y){
    if(x!=1)(val[x][y&1]+=val[fa[x]][y&1])%=p;
    for(rt i=F[x];i;i=N[i])if(a[i]!=fa[x])solve(a[i],y);
}
ll jc[100010],S[100010],ans[100010];
struct poly{
    int n,m,lim;
    ll a[2200010],b[2200010];int R[2200010];
    void init(int nn,int mm){
        n=m=mm;a[0]=1;b[0]=0;
        for(rt i=1;i<=mm;i++){
            a[i]=ksm(jc[i],p-2);
            if(i&1)a[i]=p-a[i];
            b[i]=ksm(i,nn)*ksm(jc[i],p-2)%p;
        }
        lim=1;while(lim<=n+m)lim<<=1;
        for(rt i=1;i<lim;i++)R[i]=(R[i>>1]>>1)|((i&1)*(lim>>1));
    }
    ll ksm(ll x,ll y){
        if(!y)return 1;ll ew=1;
        while(y>1){
            if(y&1)y--,ew=x*ew%p;
            y>>=1,x=x*x%p;
        }
        return x*ew%p;
    }
    void NTT(ll *A,int fla){
        for(rt i=0;i<lim;i++)if(i>R[i])swap(A[i],A[R[i]]);
        for(rt i=1;i<lim;i<<=1){
            ll w=ksm(3,998244352/2/i);
            if(fla==-1)w=ksm(w,p-2);
            for(rt j=0;j<lim;j+=i<<1){
                ll K=1;
                for(rt k=0;k<i;k++,K=K*w%p){
                    const ll x=A[j+k],y=K*A[i+j+k]%p;
                    A[j+k]=(x+y)%p;A[i+j+k]=(x-y)%p;
                }
            }
        }
    }
    void main(int nn,int mm){
        init(nn,mm);
        NTT(a,1);NTT(b,1);
        for(rt i=0;i<lim;i++)a[i]=a[i]*b[i]%p;
        NTT(a,-1);
        for(rt i=0;i<=n;i++)S[i]=(a[i]*ksm(lim,p-2)%p+p)%p;
    }
}NTT;
int main(){ 
    n=read();m=read();inv[0]=inv[1]=jc[0]=jc[1]=1;
    if(m==0){
        for(rt i=2;i<=n;i++)writeln(1);
        return 0;
    }   
    for(rt i=2;i<=n;i++)inv[i]=inv[p%i]*(p-p/i)%p;
    for(rt i=2;i<=m;i++)jc[i]=jc[i-1]*i%p;
    NTT.main(m,m);
 
    for(rt i=1;i<n;i++){
        x=read();y=read();
        d[x]++;d[y]++;
        add(x,y);add(y,x);
    }
    for(rt i=1;i<=n;i++)val[i][0]=1;
 
    dfs(1,1);
    for(rt i=2;i<=n;i++)yd[i]=read()*ksm(1000000,p-2)%p;
    for(rt x=2;x<=n;x++){
        stop[x]=yd[x];
        for(rt i=F[x];i;i=N[i])if(a[i]!=fa[x])(stop[x]+=(1ll-yd[x])*inv[d[x]]%p)%=p;
        (stop[x]+=p)%=p;        
    }
    for(rt i=2;i<=n;i++)stop[i]=ksm(p+1-stop[i],p-2);
    for(rt i=0;i<=m;i++){
        if(i){
            for(rt j=1;j<=n;j++)val[j][i&1]=0;
            for(rt j=F[1];j;j=N[j])without(a[j],i);
            solve(1,i);
        }
        for(rt j=1;j<=n;j++)(ans[j]+=val[j][i&1]*S[i]%p*jc[i])%=p;
    }
     
    for(rt i=2;i<=n;i++)writeln((ans[i]+p)%p);
    return 0;
}

 



posted @ 2018-10-30 20:20  Kananix  阅读(872)  评论(0编辑  收藏  举报

Contact with me