bzoj5093
NTT+组合数学
$把每个点分别按度数考虑,由于有标号,可以得出$
$ans=n*2^{(n-1)*(n-2)}*\sum_{i=1}^{n-1}{C(n-1,i)*i^{k}}$
$本质上是求\sum_{i=1}^{n}{C(n,i)*i^{k}}$
$组合数永远是一个比较好化简的东西,问题在于i^{k}$
$通常有两种方式可以化简,一种利用第二类斯特林数,另一种是利用伯努利数,这里利用第二类斯特林数的组合意义$
$考虑i^{k}的组合意义,相当于把k个不同的物品放进i个不同的箱子$
$S(n,k)表示把n个不同元素拆成k个集合的方案数$
$注意这里要枚举集合数,因为箱子允许空,集合不允许空,所以转化成加法原理$
$考虑转化,枚举多少个箱子有东西,也就是拆成几个集合$
$n^{k}=\sum_{i=1}^{n}{S(k,i)*C(n,i)*i!}$
$得出\sum_{i=1}^{n}{C(n,i)*i^{k}}=\sum_{i=1}^{n}{C(n,i)\sum_{j=1}^{i}{S(k,j)*C(i,j)*j!}}$
$然而并无卵用,这个东西还是不能快速求$
$改变求和顺序\sum_{j=1}^{n}{S(k,j)*j!\sum_{i=j}^{n}{C(n,i)*C(i,j)}}$
$考虑里面组合数的意义,C(n,i)*C(i,j),表示先从n个里选了i个,再从i个里选j个,相当于先选j个,剩下的随便选不选$
$那么就是C(n,j)*2^{n-j},里面的sigma消掉了$
$所以变成\sum_{j=1}^{n}{S(k,j)*j!*C(n,j)*2^{n-j}}$
$问题就变成了如何快速求第二类斯特林数$
$由于给定了k,所以我们可以通过NTT快速预处理出斯特林数$
$考虑容斥原理,S(k,j)表示将k个不同的数划分成j个集合,那么j个集合都非空$
$用容斥原理弱化这个式子,\frac{1}{j!}\sum_{i=0}^{j}{(-1)^{i}*C(j,i)*(j-i)^{k}}$
$先保证至少,那么i个集合是空的,剩下的随便放,那么放k个数的方案就是(j-i)^{k}$
$但是这里的集合是无序的,所以我们还得除一个j!$
$化简一下得出S(k,j)=\sum_{i=0}^{j}{\frac{(-1)^{i}}{i!}*\frac{(j-i)^{k}}{(j-i)!}}$
$这很卷积,那么直接NTT预处理第二类斯特林数,然后O(n)算出答案即可$
#include<bits/stdc++.h> using namespace std; typedef long long ll; const int N = 4e5 + 5, P = 998244353; int n, m, k, len; ll ans; ll a[N << 1], b[N << 1], fac[N], inv[N], facinv[N], fac_n[N]; ll power(ll x, ll t) { ll ret = 1; for(; t; t >>= 1, x = x * x % P) { if(t & 1) { ret = ret * x % P; } } return ret; } void ntt(ll *a, int f) { for(int i = 0; i < n; ++i) { int t = 0; for(int j = 0; j < len; ++j) { if(i >> j & 1) { t |= 1 << (len - j - 1); } } if(i < t) { swap(a[i], a[t]); } } for(int l = 2; l <= n; l <<= 1) { int m = l >> 1; ll w = power(3, f == 1 ? (P - 1) / l : (P - 1) - (P - 1) / l); for(int i = 0; i < n; i += l) { ll t = 1; for(int k = 0; k < m; ++k, t = t * w % P) { ll x = a[i + k], y = t * a[i + m + k] % P; a[i + k] = (x + y) % P; a[i + k + m] = ((x - y) % P + P) % P; } } } if(f == -1) { ll inv = power(n, P - 2); for(int i = 0; i < n; ++i) { a[i] = a[i] * inv % P; } } } int main() { scanf("%d%d", &m, &k); --m; fac[0] = 1; inv[1] = 1; facinv[0] = 1; fac_n[0] = 1; for(int i = 1; i <= k; ++i) { fac[i] = fac[i - 1] * i % P; fac_n[i] = fac_n[i - 1] * (m - i + 1) % P; if(i != 1) { inv[i] = (P - P / i) * inv[P % i] % P; } facinv[i] = facinv[i - 1] * inv[i] % P; } n = 1; len = 0; while(n <= k * 2) { n <<= 1; ++len; } for(int i = 0; i <= k; ++i) { a[i] = facinv[i] * ((i & 1) ? -1 : 1); a[i] = ((a[i] % P) + P) % P; b[i] = power(i, k) * facinv[i] % P; } ntt(a, 1); ntt(b, 1); for(int i = 0; i < n; ++i) { a[i] = a[i] * b[i] % P; } ntt(a, -1); for(int i = 0; i <= k && i <= m; ++i) { ans = (ans + a[i] * fac[i] % P * facinv[i] % P * fac_n[i] % P * power(2, m - i) % P) % P; } ans = ans * (m + 1) % P * power(2, (ll)m * (m - 1) / 2) % P; printf("%lld\n", ans); return 0; }