CF960G
题意
问有多少个 \(n\) 个数的排列使得前缀最大值有 \(A\) 个,后缀最大值有 \(B\) 个。
\(1\ \leq\ n\ \leq\ 10^5,\ 0\ \leq\ A,\ B\ \leq\ n\)
做法1
找到排列中 \(n\) 的位置,将排列分成两段,\(n\) 之前和 \(n\) 之后。
令 \(f(len,\ num)\) 表示长度为 \(len\) 的排列,前缀最大值有 \(num\) 个,则答案为 \(\sum_{i\ =\ 1}^n\ f(i\ -\ 1,\ A\ -\ 1)\ f(n\ -\ i,\ B\ -\ 1)\ \binom{n\ -\ 1}{i\ -\ 1}\)。观察 \(f\) 转移式,发现与无符号第一类斯特林数一样,从组合意义理解:将长度为 \(len\) 的排列分成 \(num\) 段,第 \(i\) 段是 \([\) 第 \(i\) 个前缀最大值 \(,\) 第 \(i\ +\ 1\) 个前缀最大值 \()\)。这与将 \(len\) 个球放入 \(num\) 个盒中,每个盒子有 \((siz\ -\ 1)!\) 中排列方法等价。再观察所求答案,可以发现 \(ans\ =\ f(n\ -\ 1,\ A\ +\ B\ -\ 2)\ \binom{A\ +\ B\ -\ 2}{A\ -\ 1}\)。从组合意义理解:把 \(n\ -\ 1\) 个小球先放入 \(A\ +\ B\ -\ 2\) 个盒中,再选出 \(A\ -\ 1\) 个盒子放在 \(n\) 前面。
求 \(f(n,\ m)\) 直接用 \(f(n,\ m)\ =\ [x^m]\ \Pi_{i\ =\ 0}^{n\ -\ 1}\ (x\ +\ i)\) 即可。
代码
#include <bits/stdc++.h>
#ifdef __WIN32
#define LLFORMAT "I64"
#else
#define LLFORMAT "ll"
#endif
using namespace std;
const int mod = 998244353, proot = 3;
int main() {
int n, A, B;
cin >> n >> A >> B;
if(!A || !B) { cout << 0 << endl; return 0; }
if(n == 1) { cout << 1 << endl; return 0; }
auto pow_mod = [&](int x, int n) {
int y = 1;
while(n) {
if(n & 1) y = (long long) y * x % mod;
x = (long long) x * x % mod;
n >>= 1;
}
return y;
};
function<void(vector<int> &, bool)> dft = [&](vector<int> &a, bool rev) {
int n = a.size();
for (int i = 0, j = 0; i < n; ++i) {
if(i < j) swap(a[i], a[j]);
for (int k = n >> 1; (j ^= k) < k; k >>= 1);
}
for (int hl = 1, l = 2; l <= n; hl = l, l <<= 1) {
int wn = pow_mod(proot, (mod - 1) / l);
if(rev) wn = pow_mod(wn, mod - 2);
for (int i = 0; i < n; i += l) for (int w = 1, j = 0; j < hl; ++j) {
int t = (long long) a[i + j + hl] * w % mod;
a[i + j + hl] = (a[i + j] - t) % mod;
a[i + j] = (a[i + j] + t) % mod;
w = (long long) w * wn % mod;
}
}
if(rev) {
int inv = pow_mod(n, mod - 2);
for (int i = 0; i < n; ++i) a[i] = (long long) a[i] * inv % mod;
}
return;
};
function<void(vector<int> &, vector<int> &)> mul = [&](vector<int> &a, vector<int> &b) {
int n = a.size(), m = b.size(), N = 1;
while(N <= n + m - 2) N <<= 1;
a.resize(N); b.resize(N);
dft(a, 0); dft(b, 0);
for (int i = 0; i < N; ++i) a[i] = (long long) a[i] * b[i] % mod;
dft(a, 1);
a.resize(n + m - 1);
return;
};
vector<vector<int> > f(1, vector<int>{0, 1});
for (int i = 1; i < n - 1; ++i) {
f.push_back(vector<int>{i, 1});
while(f.size() >= 2 && f.back().size() >= f[f.size() - 2].size()) {
mul(f[f.size() - 2], f.back());
f.pop_back();
}
}
while(f.size() >= 2) {
mul(f[f.size() - 2], f.back());
f.pop_back();
}
if(A + B - 2 >= f[0].size()) { cout << 0 << endl; return 0; }
int ans = f[0][A + B - 2];
for (int i = 1; i <= A - 1; ++i) ans = (long long) ans * pow_mod(i, mod - 2) % mod * (A + B - i - 1) % mod;
cout << (ans + mod) % mod << endl;
return 0;
}