bzoj5093

NTT+组合数学

$把每个点分别按度数考虑,由于有标号,可以得出$

$ans=n*2^{(n-1)*(n-2)}*\sum_{i=1}^{n-1}{C(n-1,i)*i^{k}}$

$本质上是求\sum_{i=1}^{n}{C(n,i)*i^{k}}$

$组合数永远是一个比较好化简的东西,问题在于i^{k}$

$通常有两种方式可以化简,一种利用第二类斯特林数,另一种是利用伯努利数,这里利用第二类斯特林数的组合意义$

$考虑i^{k}的组合意义,相当于把k个不同的物品放进i个不同的箱子$

$S(n,k)表示把n个不同元素拆成k个集合的方案数$

$注意这里要枚举集合数,因为箱子允许空,集合不允许空,所以转化成加法原理$

$考虑转化,枚举多少个箱子有东西,也就是拆成几个集合$

$n^{k}=\sum_{i=1}^{n}{S(k,i)*C(n,i)*i!}$

$得出\sum_{i=1}^{n}{C(n,i)*i^{k}}=\sum_{i=1}^{n}{C(n,i)\sum_{j=1}^{i}{S(k,j)*C(i,j)*j!}}$

$然而并无卵用,这个东西还是不能快速求$

$改变求和顺序\sum_{j=1}^{n}{S(k,j)*j!\sum_{i=j}^{n}{C(n,i)*C(i,j)}}$

$考虑里面组合数的意义,C(n,i)*C(i,j),表示先从n个里选了i个,再从i个里选j个,相当于先选j个,剩下的随便选不选$

$那么就是C(n,j)*2^{n-j},里面的sigma消掉了$

$所以变成\sum_{j=1}^{n}{S(k,j)*j!*C(n,j)*2^{n-j}}$

$问题就变成了如何快速求第二类斯特林数$

$由于给定了k,所以我们可以通过NTT快速预处理出斯特林数$

$考虑容斥原理,S(k,j)表示将k个不同的数划分成j个集合,那么j个集合都非空$

$用容斥原理弱化这个式子,\frac{1}{j!}\sum_{i=0}^{j}{(-1)^{i}*C(j,i)*(j-i)^{k}}$

$先保证至少,那么i个集合是空的,剩下的随便放,那么放k个数的方案就是(j-i)^{k}$

$但是这里的集合是无序的,所以我们还得除一个j!$

$化简一下得出S(k,j)=\sum_{i=0}^{j}{\frac{(-1)^{i}}{i!}*\frac{(j-i)^{k}}{(j-i)!}}$

$这很卷积,那么直接NTT预处理第二类斯特林数,然后O(n)算出答案即可$

 

#include<bits/stdc++.h>
using namespace std;
typedef long long ll;
const int N = 4e5 + 5, P = 998244353;
int n, m, k, len;
ll ans;
ll a[N << 1], b[N << 1], fac[N], inv[N], facinv[N], fac_n[N];
ll power(ll x, ll t) {
    ll ret = 1;
    for(; t; t >>= 1, x = x * x % P) {
        if(t & 1) {
            ret = ret * x % P;
        }
    } 
    return ret;
}
void ntt(ll *a, int f) {
    for(int i = 0; i < n; ++i) {
        int t = 0;
        for(int j = 0; j < len; ++j) {
            if(i >> j & 1) {
                t |= 1 << (len - j - 1); 
            }
        }
        if(i < t) {
            swap(a[i], a[t]);
        }
    }
    for(int l = 2; l <= n; l <<= 1) {
        int m = l >> 1;
        ll w = power(3, f == 1 ? (P - 1) / l : (P - 1) - (P - 1) / l);
        for(int i = 0; i < n; i += l) {
            ll t = 1;
            for(int k = 0; k < m; ++k, t = t * w % P) {
                ll x = a[i + k], y = t * a[i + m + k] % P;
                a[i + k] = (x + y) % P;
                a[i + k + m] = ((x - y) % P + P) % P;
            }
        }       
    }
    if(f == -1) {
        ll inv = power(n, P - 2);
        for(int i = 0; i < n; ++i) {
            a[i] = a[i] * inv % P;
        }
    }
}
int main() {
    scanf("%d%d", &m, &k);
    --m;
    fac[0] = 1;
    inv[1] = 1;
    facinv[0] = 1;
    fac_n[0] = 1;
    for(int i = 1; i <= k; ++i) {
        fac[i] = fac[i - 1] * i % P;
        fac_n[i] = fac_n[i - 1] * (m - i + 1) % P;
        if(i != 1) {
            inv[i] = (P - P / i) * inv[P % i] % P;
        }
        facinv[i] = facinv[i - 1] * inv[i] % P;
    }
    n = 1;
    len = 0;
    while(n <= k * 2) {
        n <<= 1;
        ++len;
    }
    for(int i = 0; i <= k; ++i) {
        a[i] = facinv[i] * ((i & 1) ? -1 : 1);
        a[i] = ((a[i] % P) + P) % P;
        b[i] = power(i, k) * facinv[i] % P;
    }
    ntt(a, 1);
    ntt(b, 1);
    for(int i = 0; i < n; ++i) {
        a[i] = a[i] * b[i] % P;
    }
    ntt(a, -1);
    for(int i = 0; i <= k && i <= m; ++i) {
        ans = (ans + a[i] * fac[i] % P * facinv[i] % P * fac_n[i] % P * power(2, m - i) % P) % P;
    }
    ans = ans * (m + 1) % P * power(2, (ll)m * (m - 1) / 2) % P;
    printf("%lld\n", ans);
    return 0;
}
View Code

 

posted @ 2018-02-23 20:12  19992147  阅读(132)  评论(0编辑  收藏  举报