CF960G Bandit Blues 第一类斯特林数+分治+FFT
题目传送门
https://codeforces.com/contest/960/problem/G
题解
首先整个排列的最大值一定是 \(A\) 个前缀最大值的最后一个,也是 \(B\) 个后缀最大值的最后一个。
那么枚举一下最大值的位置为 \(i\),那么左右两边各选一些数的方案数为 \(\binom {n-1}{i-1}\)。
然后,左边有 \(i-1\) 个数,要分成 \(A-1\) 个部分,每一个部分的第一个数是所有数中最大的,并且每一个部分之间的最大值要递增。
可以发现这个问题等价于把 \(i-1\) 个数分成 \(A-1\) 个环——不是排列是因为第一个数的位置是固定的。因此可以很容易地写出答案
\[\sum_{i=1}^{n-1} \binom{n-1}{i-1}\begin{bmatrix}i-1\\A-1\end{bmatrix}\begin{bmatrix}n-i\\B-1\end{bmatrix}
\]
发现既然是先把一堆数分成两份再每一份分成 \(A-1\) 和 \(B-1\) 个环,那么也可以看成先分成 \(A+B-2\) 个环然后再把环分成两份。
因此方案数等价于
\[\begin{bmatrix}n-1\\A+B-2\end{bmatrix}\binom{A+B-2}{A-1}
\]
求第一类斯特林数可以用分治+FFT。
#include<bits/stdc++.h>
#define fec(i, x, y) (int i = head[x], y = g[i].to; i; i = g[i].ne, y = g[i].to)
#define dbg(...) fprintf(stderr, __VA_ARGS__)
#define File(x) freopen(#x".in", "r", stdin), freopen(#x".out", "w", stdout)
#define fi first
#define se second
#define pb push_back
template<typename A, typename B> inline char smax(A &a, const B &b) {return a < b ? a = b , 1 : 0;}
template<typename A, typename B> inline char smin(A &a, const B &b) {return b < a ? a = b , 1 : 0;}
typedef long long ll; typedef unsigned long long ull; typedef std::pair<int, int> pii;
template<typename I>
inline void read(I &x) {
int f = 0, c;
while (!isdigit(c = getchar())) c == '-' ? f = 1 : 0;
x = c & 15;
while (isdigit(c = getchar())) x = (x << 1) + (x << 3) + (c & 15);
f ? x = -x : 0;
}
const int N = 4e5 + 7;
const int P = 998244353;
const int Gi = 332748118;
const int G = 3;
int n, Aa, Bb, nlg;
int A[N], B[N], pw[20][N], cc[20];
int fac[N], inv[N], ifac[N];
inline void ycl(const int &n = ::n) {
fac[0] = 1; for (int i = 1; i <= n; ++i) fac[i] = (ll)fac[i - 1] * i % P;
inv[1] = 1; for (int i = 2; i <= n; ++i) inv[i] = (ll)(P - P / i) * inv[P % i] % P;
ifac[0] = 1; for (int i = 1; i <= n; ++i) ifac[i] = (ll)ifac[i - 1] * inv[i] % P;
}
inline int C(int x, int y) {
if (x < y) return 0;
return (ll)fac[x] * ifac[y] % P * ifac[x - y] % P;
}
inline int smod(int x) { return x >= P ? x - P : x; }
inline void sadd(int &x, const int &y) { x += y; x >= P ? x -= P : x; }
inline int fpow(int x, int y) {
int ans = 1;
for (; y; y >>= 1, x = (ll)x * x % P) if (y & 1) ans = (ll)ans * x % P;
return ans;
}
inline void NTT(int *a, int n, int f) {
for (int i = 0, j = 0; i < n; ++i) {
if (i > j) std::swap(a[i], a[j]);
for (int l = n >> 1; (j ^= l) < l; l >>= 1) ;
}
for (int i = 1; i < n; i <<= 1) {
int w = fpow(f > 0 ? G : Gi, (P - 1) / (i << 1));
for (int j = 0; j < n; j += i << 1)
for (int k = 0, e = 1; k < i; ++k, e = (ll)e * w % P) {
int x = a[j + k], y = (ll)e * a[i + j + k] % P;
a[j + k] = smod(x + y), a[i + j + k] = smod(x + P - y);
}
}
if (f < 0) for (int i = 0, p = fpow(n, P - 2); i < n; ++i) a[i] = (ll)a[i] * p % P;
}
inline void Mul(int *a, int *b, int *c, int n, int m) {
int l = 1;
while (l <= n + m) l <<= 1;
for (int i = 0; i <= n; ++i) A[i] = a[i];
for (int i = n + 1; i < l; ++i) A[i] = 0;
for (int i = 0; i <= m; ++i) B[i] = b[i];
for (int i = m + 1; i < l; ++i) B[i] = 0;
NTT(A, l, 1), NTT(B, l, 1);
for (int i = 0; i < l; ++i) A[i] = (ll)A[i] * B[i] % P;
NTT(A, l, -1);
for (int i = 0; i <= n + m; ++i) c[i] = A[i];
}
int a[N], b[N], c[N];
inline int Calc_Strling1(int n, int m) {
if (!n) return m == 0;
if (m > n || m <= 0) return 0;
a[1] = 1;
for (int k = 0; k < nlg; ++k) {
int l = cc[k];
for (int i = 0; i <= l; ++i) b[i] = (ll)a[l - i] * fac[l - i] % P
for (int i = 0; i <= l; ++i) c[i] = (ll)pw[k][i] * ifac[i] % P;
Mul(b, c, b, l, l);
for (int i = 0; i <= l; ++i) c[i] = (ll)b[l - i] * ifac[i] % P;
Mul(a, c, a, l, l);
if (cc[k + 1] & 1) {
l = cc[k + 1];
for (int i = l; i; --i) a[i] = (a[i - 1] + (l - 1ll) * a[i]) % P;
a[0] = (l - 1ll) * a[0] % P;
}
}
return a[m];
}
inline void work() {
ycl();
cc[0] = n - 1;
while (cc[nlg] >> 1) cc[nlg + 1] = cc[nlg] >> 1, ++nlg;
std::reverse(cc, cc + nlg + 1);
for (int i = 0; i < nlg; ++i) {
pw[i][0] = 1;
for (int j = 1; j <= n; ++j) pw[i][j] = (ll)pw[i][j - 1] * cc[i] % P;
}
printf("%I64d\n", (ll)Calc_Strling1(n - 1, Aa + Bb - 2) * C(Aa + Bb - 2, Aa - 1) % P);
}
inline void init() {
read(n), read(Aa), read(Bb);
}
int main() {
#ifdef hzhkk
freopen("hkk.in", "r", stdin);
#endif
init();
work();
fclose(stdin), fclose(stdout);
return 0;
}