【BZOJ4555】求和(TJOI&HEOI2016)-第二类斯特林数+NTT

测试地址:求和
做法:本题需要用到第二类斯特林数+NTT。
从题目中给的递推式或者根据组合数学的知识,第二类斯特林数S(i,j)的组合意义是:将i个有区别的球放入j个无区别的盒子的方案数。由此我们可以得到通项公式:
S(i,j)=1j!k=0j(1)kCjk(jk)i
这其实就是一个容斥的形式,相当于枚举强制哪些盒子是空的,至于要乘一个1j!是因为盒子无区别,而里面算的方案是有区别的。
那么将这个式子代入题目要求的式子,有:
f(n)=i=0nj=0i2jk=0j(1)kCjk(jk)i
将组合数拆开,整理得:
f(n)=i=0nj=0i2jj!k=0j(1)kk!(jk)i(jk)!
我们发现后半部分已经很像一个卷积的形式了,但是因为它还和i有关,所以我们想办法把i换进去。
我们知道当j>iS(i,j)=0,所以上式中j的上限可以换成n,那么就可以把i换进去,得到:
f(n)=j=0n2jj!k=0j(1)kk!i=0n(jk)i(jk)!
那么这个式子的后半部分就是函数g(x)=(1)xx!和函数h(x)=i=0nxix!的卷积了,可以用NTT求出,而求h(x)时,我们发现它是一个等比数列的前缀和,直接用等比数列求和公式求即可。特别地,h(0)=1,h(1)=n+1,直接用公式算的话这两个会算错。
以下是本人代码:

#include <bits/stdc++.h>
using namespace std;
typedef long long ll;
const ll mod=998244353;
const ll g=3;
ll n,fac[100010],inv[100010],invfac[100010];
ll a[1000010]={0},b[1000010]={0};
int r[1000010];

ll power(ll a,ll b)
{
    ll s=1,ss=a;
    while(b)
    {
        if (b&1) s=s*ss%mod;
        ss=ss*ss%mod;b>>=1;
    }
    return s;
}

void NTT(ll *a,ll type,int n)
{
    for(int i=0;i<n;i++)
        if (i<r[i]) swap(a[i],a[r[i]]);
    for(int mid=1;mid<n;mid<<=1)
    {
        ll W=power(g,(mod-1)/(mid<<1));
        if (type==-1) W=power(W,mod-2);
        for(int l=0;l<n;l+=(mid<<1))
        {
            ll w=1;
            for(int k=0;k<mid;k++,w=w*W%mod)
            {
                ll x=a[l+k],y=w*a[l+mid+k]%mod;
                a[l+k]=(x+y)%mod;
                a[l+mid+k]=(x-y+mod)%mod;
            }
        }
    }
    if (type==-1)
    {
        ll inv=power(n,mod-2);
        for(int i=0;i<n;i++)
            a[i]=a[i]*inv%mod;
    }
}

int main()
{
    scanf("%lld",&n);

    fac[0]=fac[1]=inv[1]=invfac[0]=invfac[1]=1;
    for(ll i=2;i<=n;i++)
    {
        fac[i]=fac[i-1]*i%mod;
        inv[i]=(mod-mod/i)*inv[mod%i]%mod;
        invfac[i]=invfac[i-1]*inv[i]%mod;
    }

    for(ll i=0;i<=n;i++)
    {
        a[i]=(((i%2)?-1:1)*invfac[i]+mod)%mod;
        if (i==0) b[i]=1;
        if (i==1) b[i]=n+1;
        if (i>1) b[i]=(power(i,n+1)-1+mod)*invfac[i]%mod*inv[i-1]%mod;
    }
    int x=1,bit=0;
    while(x<=(n<<2)) x<<=1,bit++;
    r[0]=0;
    for(int i=1;i<x;i++)
        r[i]=(r[i>>1]>>1)|((i&1)<<(bit-1));
    NTT(a,1,x),NTT(b,1,x);
    for(int i=0;i<x;i++)
        a[i]=a[i]*b[i]%mod;
    NTT(a,-1,x);

    ll ans=0;
    for(ll i=0,j=1;i<=n;i++,j=j*2ll%mod)
    {
        ll tmp=j*fac[i]%mod;
        tmp=tmp*a[i]%mod;
        ans=(ans+tmp)%mod;
    }
    printf("%lld",ans);

    return 0;
}
posted @ 2018-05-01 13:40  Maxwei_wzj  阅读(100)  评论(0编辑  收藏  举报