CF960G Bandit Blues 分治+NTT(第一类斯特林数)
$ \color{#0066ff}{ 题目描述 }$
给你三个正整数 \(n\),\(a\),\(b\),定义 \(A\) 为一个排列中是前缀最大值的数的个数,定义 \(B\) 为一个排列中是后缀最大值的数的个数,求长度为 \(n\) 的排列中满足 \(A = a\) 且 \(B = b\) 的排列个数。\(n \le 10^5\),答案对 \(998244353\) 取模。
\(\color{#0066ff}{输入格式}\)
三个整数n,a,b
\(\color{#0066ff}{输出格式}\)
方案数
\(\color{#0066ff}{输入样例}\)
1 1 1
2 1 1
2 2 1
5 2 2
\(\color{#0066ff}{输出样例}\)
1
0
1
22
\(\color{#0066ff}{数据范围与提示}\)
\(N\) ( $1<=N<=10^{5} $ ), $A $and $ B$ ( $0<=A,B<=N $).
\(\color{#0066ff}{题解}\)
显然当\(a+b-1>n\)时,无解
考虑DP, \(f[i][j]\)表示i的排列有j个前缀最大值的方案数
考虑枚举1的位置\(f[i][j] = f[i-1][j-1]+(i-1)*f[i-1][j]\)
这是第一类斯特林数
实际上第一维滚动之后,可以发现, 就是把整个数组移动一位再加上自己的值*dp轮数
就相当于第i轮有i次操作,有1的方案取一个球,有i-1的方案一个球不取
于是构造生成函数\(\begin{aligned}\prod_{i=0}^{n-2}(x+i)\end{aligned}\)
本来应该是n-1的,实际上n的方案从1只转移n-1次, 所以-1
这个式子可以分治+NTT快速求出
最后还要组合一下,可以考虑每个产生贡献的值,a个和b个分别分成a-1和b-1段
最后再乘一个\(C_{a+b-2}^{a-1}\)即可
#include<bits/stdc++.h>
#define LL long long
LL in() {
char ch; LL x = 0, f = 1;
while(!isdigit(ch = getchar()))(ch == '-') && (f = -f);
for(x = ch ^ 48; isdigit(ch = getchar()); x = (x << 1) + (x << 3) + (ch ^ 48));
return x * f;
}
const int maxn = 4e5 + 10;
const int mod = 998244353;
int len, r[maxn];
using std::vector;
LL ksm(LL x, LL y) {
LL re = 1LL;
while(y) {
if(y & 1) re = re * x % mod;
x = x * x % mod;
y >>= 1;
}
return re;
}
void FNTT(vector<int> &A, int flag) {
A.resize(len);
for(int i = 0; i < len; i++) if(i < r[i]) std::swap(A[i], A[r[i]]);
for(int l = 1; l < len; l <<= 1) {
int w0 = ksm(3, (mod - 1) / (l << 1));
for(int i = 0; i < len; i += (l << 1)) {
int w = 1, a0 = i, a1 = i + l;
for(int k = 0; k < l; k++, a0++, a1++, w = 1LL * w0 * w % mod) {
int tmp = 1LL * A[a1] * w % mod;
A[a1] = ((A[a0] - tmp) % mod + mod) % mod;
A[a0] = (A[a0] + tmp) % mod;
}
}
}
if(!(~flag)) {
std::reverse(A.begin() + 1, A.end());
int inv = ksm(len, mod - 2);
for(int i = 0; i < len; i++) A[i] = 1LL * A[i] * inv % mod;
}
}
vector<int> operator * (vector<int> A, vector<int> B) {
int tot = A.size() + B.size() - 1;
for(len = 1; len <= tot; len <<= 1);
for(int i = 0; i < len; i++) r[i] = (r[i >> 1] >> 1) | ((i & 1) * (len >> 1));
FNTT(A, 1), FNTT(B, 1);
std::vector<int> ans;
for(int i = 0; i < len; i++) ans.push_back(1LL * A[i] * B[i] % mod);
FNTT(ans, -1);
ans.resize(tot);
return ans;
}
int n, a, b;
std::vector<int> work(int l, int r) {
vector<int> ans;
if(l == r) {
ans.resize(2, 0);
ans[1] += 1, ans[0] += l;
return ans;
}
int mid = (l + r) >> 1;
return work(l, mid) * work(mid + 1, r);
}
LL C(int x, int y) {
LL ans1 = 1, ans2 = 1;
for(int i = y + 1; i <= x; i++) ans1 = 1LL * ans1 * i % mod;
for(int i = 1; i <= x - y; i++) ans2 = 1LL * ans2 * i % mod;
return 1LL * ans1 * ksm(ans2, mod - 2) % mod;
}
int main() {
n = in(), a = in(), b = in();
if(!a || !b || a + b - 1 > n) return puts("0"), 0;
if(n == 1) return puts("1"), 0;
std::vector<int> ans;
ans = work(0, n - 2);
printf("%lld\n", 1LL * ans[a + b - 2] * C(a + b - 2, a - 1) % mod);
return 0;
}
----olinr