CF960G Bandit Blues 分治+NTT(第一类斯特林数)

$ \color{#0066ff}{ 题目描述 }$

给你三个正整数 \(n\)\(a\)\(b\),定义 \(A\) 为一个排列中是前缀最大值的数的个数,定义 \(B\) 为一个排列中是后缀最大值的数的个数,求长度为 \(n\) 的排列中满足 \(A = a\)\(B = b\) 的排列个数。\(n \le 10^5\),答案对 \(998244353\) 取模。

\(\color{#0066ff}{输入格式}\)

三个整数n,a,b

\(\color{#0066ff}{输出格式}\)

方案数

\(\color{#0066ff}{输入样例}\)

1 1 1

2 1 1

2 2 1

5 2 2

\(\color{#0066ff}{输出样例}\)

1
    
0
    
1
    
22

\(\color{#0066ff}{数据范围与提示}\)

\(N\) ( $1<=N<=10^{5} $ ), $A $and $ B$ ( $0<=A,B<=N $).

\(\color{#0066ff}{题解}\)

显然当\(a+b-1>n\)时,无解

考虑DP, \(f[i][j]\)表示i的排列有j个前缀最大值的方案数

考虑枚举1的位置\(f[i][j] = f[i-1][j-1]+(i-1)*f[i-1][j]\)

这是第一类斯特林数

实际上第一维滚动之后,可以发现, 就是把整个数组移动一位再加上自己的值*dp轮数

就相当于第i轮有i次操作,有1的方案取一个球,有i-1的方案一个球不取

于是构造生成函数\(\begin{aligned}\prod_{i=0}^{n-2}(x+i)\end{aligned}\)

本来应该是n-1的,实际上n的方案从1只转移n-1次, 所以-1

这个式子可以分治+NTT快速求出

最后还要组合一下,可以考虑每个产生贡献的值,a个和b个分别分成a-1和b-1段

最后再乘一个\(C_{a+b-2}^{a-1}\)即可

#include<bits/stdc++.h>
#define LL long long
LL in() {
	char ch; LL x = 0, f = 1;
	while(!isdigit(ch = getchar()))(ch == '-') && (f = -f);
	for(x = ch ^ 48; isdigit(ch = getchar()); x = (x << 1) + (x << 3) + (ch ^ 48));
	return x * f;
}
const int maxn = 4e5 + 10;
const int mod = 998244353;
int len, r[maxn];
using std::vector;
LL ksm(LL x, LL y) {
	LL re = 1LL;
	while(y) {
		if(y & 1) re = re * x % mod;
		x = x * x % mod;
		y >>= 1;
	}
	return re;
}
void FNTT(vector<int> &A, int flag) {
	A.resize(len);
	for(int i = 0; i < len; i++) if(i < r[i]) std::swap(A[i], A[r[i]]);
	for(int l = 1; l < len; l <<= 1) {
		int w0 = ksm(3, (mod - 1) / (l << 1));
		for(int i = 0; i < len; i += (l << 1)) {
			int w = 1, a0 = i, a1 = i + l;
			for(int k = 0; k < l; k++, a0++, a1++, w = 1LL * w0 * w % mod) {
				int tmp = 1LL * A[a1] * w % mod;
				A[a1] = ((A[a0] - tmp) % mod + mod) % mod;
				A[a0] = (A[a0] + tmp) % mod;
			}
		}
	}
	if(!(~flag)) {
		std::reverse(A.begin() + 1, A.end());
		int inv = ksm(len, mod - 2);
		for(int i = 0; i < len; i++) A[i] = 1LL * A[i] * inv % mod;
	}
}
vector<int> operator * (vector<int> A, vector<int> B) {
	int tot = A.size() + B.size() - 1;
	for(len = 1; len <= tot; len <<= 1);
	for(int i = 0; i < len; i++) r[i] = (r[i >> 1] >> 1) | ((i & 1) * (len >> 1));
	FNTT(A, 1), FNTT(B, 1);
	std::vector<int> ans;
	for(int i = 0; i < len; i++) ans.push_back(1LL * A[i] * B[i] % mod);
	FNTT(ans, -1);
	ans.resize(tot);
	return ans;
}
int n, a, b;
std::vector<int> work(int l, int r) {
	vector<int> ans;
	if(l == r) {
		ans.resize(2, 0);
		ans[1] += 1, ans[0] += l;
		return ans;
	}
	int mid = (l + r) >> 1;
    return work(l, mid) * work(mid + 1, r);
}
LL C(int x, int y) {
	LL ans1 = 1, ans2 = 1;
	for(int i = y + 1; i <= x; i++) ans1 = 1LL * ans1 * i % mod;
	for(int i = 1; i <= x - y; i++) ans2 = 1LL * ans2 * i % mod;
	return 1LL * ans1 * ksm(ans2, mod - 2) % mod;
}
int main() {
	n = in(), a = in(), b = in();
	if(!a || !b || a + b - 1 > n) return puts("0"), 0;
	if(n == 1) return puts("1"), 0;
	std::vector<int> ans;
	ans = work(0, n - 2);
	printf("%lld\n", 1LL * ans[a + b - 2] * C(a + b - 2, a - 1) % mod);
	return 0;
}
posted @ 2019-03-22 20:47  olinr  阅读(280)  评论(0编辑  收藏  举报