[CF960G] Bandit Blues(第一类斯特林数+分治NTT)

Solution

  • \(O(n^2)\) 做法不会的先去看这个
  • 这里只讲如何快速求第一类斯特林数 \(s(n,m)\)
  • 首先有递推式:\(s(i,j)=s(i-1,j-1)+(i-1)*s(i-1,j)\)
  • 为方便卷积写成这样(第二维和为 \(j\)):\(s(i,j)=s(i-1,j-1)*b(i,1)+b(i,0)*s(i-1,j)\)
  • 其中 \(b(i,1)=1,b(i,0)=i-1\)
  • 那么把 \(s(i)\) 看成一个多项式,\(s(i,j)\) 为这个多项式 \(x^j\) 项的系数,初值:\(s(0,0)=1\)
  • \(b(i)\) 同理
  • 那么 \(s(i)=s(i-1)*b(i)\)
  • 于是把 \(s(0)\) ~ \(s(n)\) 都乘起来,得到的多项式就是 \(s(n)\)
  • 这个多项式的 \(x^i\) 项的系数就是 \(s(n,i)\)
  • 分治 \(ntt\) 即可,时间复杂度 \(O(n \log^2n)\)

code

#include <bits/stdc++.h>

using namespace std;

#define ll long long

const int e = 1e6 + 5, mod = 998244353;
int n, a1, b1, fac[e], inv[e], rev[e], lim;
vector<int>g[e];

inline int ksm(int x, int y)
{
	int res = 1;
	while (y)
	{
		if (y & 1) res = (ll)res * x % mod;
		y >>= 1;
		x = (ll)x * x % mod;
	}
	return res;
}

inline void upt(int &x, int y)
{
	x = y;
	if (x >= mod) x -= mod;
}

inline void fft(int n, int *a, int opt)
{
	int i, j, k, r = (opt == 1 ? 3 : (mod + 1) / 3);
	for (i = 0; i < n; i++)
	if (i < rev[i]) swap(a[i], a[rev[i]]);
	for (k = 1; k < n; k <<= 1)
	{
		int w0 = ksm(r, (mod - 1) / (k << 1));
		for (i = 0; i < n; i += (k << 1))
		{
			int w = 1;
			for (j = 0; j < k; j++)
			{
				int b = a[i + j], c = (ll)w * a[i + j + k] % mod;
				upt(a[i + j], b + c);
				upt(a[i + j + k], b + mod - c);
				w = (ll)w * w0 % mod;
			}
		}
	}
}

inline void solve(int l, int r)
{
	if (l >= r) return;
	int i, mid = l + r >> 1;
	solve(l, mid);
	solve(mid + 1, r);
	static int a[266666], b[266666], c[266666];
	int k = 0, la = g[l].size(), lb = g[mid + 1].size();
	lim = 1;
	while (lim < la + lb - 1)
	{
		lim <<= 1;
		k++;
	}
	for (i = 0; i < lim; i++) 
	{
		a[i] = b[i] = 0;
		rev[i] = (rev[i >> 1] >> 1) | ((i & 1) << k - 1);
	}
	for (i = 0; i < la; i++) a[i] = g[l][i];
	for (i = 0; i < lb; i++) b[i] = g[mid + 1][i];
	fft(lim, a, 1);
	fft(lim, b, 1);
	for (i = 0; i < lim; i++) a[i] = (ll)a[i] * b[i] % mod;
	fft(lim, a, -1);
	int tot = ksm(lim, mod - 2);
	for (i = 0; i < lim; i++) a[i] = (ll)a[i] * tot % mod;
	g[l].clear(); 
	for (i = 0; i < la + lb - 1; i++) g[l].push_back(a[i]); 
}

inline int c(int x, int y)
{
	if (x < y) return 0;
	return (ll)fac[x] * inv[y] % mod * inv[x - y] % mod;
}

int main()
{
	int i;
	cin >> n >> a1 >> b1;
	fac[0] = 1;
	for (i = 1; i <= n; i++) fac[i] = (ll)fac[i - 1] * i % mod;
	inv[n] = ksm(fac[n], mod - 2);
	for (i = n - 1; i >= 0; i--) inv[i] = (ll)inv[i + 1] * (i + 1) % mod;
	int res = c(a1 + b1 - 2, a1 - 1);
	g[0].push_back(1);
	for (i = 1; i <= n; i++)
	{
		g[i].push_back(i - 1);
		g[i].push_back(1);
	}
	solve(0, n - 1);
	if (a1 + b1 - 2 < g[0].size()) res = (ll)res * g[0][a1 + b1 - 2] % mod;
	else res = 0;
	cout << res << endl;
	return 0;
}
posted @ 2020-01-15 13:45  花淇淋  阅读(133)  评论(0编辑  收藏  举报