CF960G

题意

问有多少个 \(n\) 个数的排列使得前缀最大值有 \(A\) 个,后缀最大值有 \(B\) 个。

\(1\ \leq\ n\ \leq\ 10^5,\ 0\ \leq\ A,\ B\ \leq\ n\)

做法1

找到排列中 \(n\) 的位置,将排列分成两段,\(n\) 之前和 \(n\) 之后。
\(f(len,\ num)\) 表示长度为 \(len\) 的排列,前缀最大值有 \(num\) 个,则答案为 \(\sum_{i\ =\ 1}^n\ f(i\ -\ 1,\ A\ -\ 1)\ f(n\ -\ i,\ B\ -\ 1)\ \binom{n\ -\ 1}{i\ -\ 1}\)。观察 \(f\) 转移式,发现与无符号第一类斯特林数一样,从组合意义理解:将长度为 \(len\) 的排列分成 \(num\) 段,第 \(i\) 段是 \([\)\(i\) 个前缀最大值 \(,\)\(i\ +\ 1\) 个前缀最大值 \()\)。这与将 \(len\) 个球放入 \(num\) 个盒中,每个盒子有 \((siz\ -\ 1)!\) 中排列方法等价。再观察所求答案,可以发现 \(ans\ =\ f(n\ -\ 1,\ A\ +\ B\ -\ 2)\ \binom{A\ +\ B\ -\ 2}{A\ -\ 1}\)。从组合意义理解:把 \(n\ -\ 1\) 个小球先放入 \(A\ +\ B\ -\ 2\) 个盒中,再选出 \(A\ -\ 1\) 个盒子放在 \(n\) 前面。
\(f(n,\ m)\) 直接用 \(f(n,\ m)\ =\ [x^m]\ \Pi_{i\ =\ 0}^{n\ -\ 1}\ (x\ +\ i)\) 即可。

代码

#include <bits/stdc++.h>

#ifdef __WIN32
#define LLFORMAT "I64"
#else
#define LLFORMAT "ll"
#endif

using namespace std;

const int mod = 998244353, proot = 3;

int main() {
	int n, A, B;
	cin >> n >> A >> B;

	if(!A || !B) { cout << 0 << endl; return 0; }
	if(n == 1) { cout << 1 << endl; return 0; }

	auto pow_mod = [&](int x, int n) {
		int y = 1;
		while(n) {
			if(n & 1) y = (long long) y * x % mod;
			x = (long long) x * x % mod;
			n >>= 1;
		}
		return y;
	};

	function<void(vector<int> &, bool)> dft = [&](vector<int> &a, bool rev) {
		int n = a.size();
		for (int i = 0, j = 0; i < n; ++i) {
			if(i < j) swap(a[i], a[j]);
			for (int k = n >> 1; (j ^= k) < k; k >>= 1);
		}
		for (int hl = 1, l = 2; l <= n; hl = l, l <<= 1) {
			int wn = pow_mod(proot, (mod - 1) / l);
			if(rev) wn = pow_mod(wn, mod - 2);
			for (int i = 0; i < n; i += l) for (int w = 1, j = 0; j < hl; ++j) {
				int t = (long long) a[i + j + hl] * w % mod;
				a[i + j + hl] = (a[i + j] - t) % mod;
				a[i + j] = (a[i + j] + t) % mod;
				w = (long long) w * wn % mod;
			}
		}
		if(rev) {
			int inv = pow_mod(n, mod - 2);
			for (int i = 0; i < n; ++i) a[i] = (long long) a[i] * inv % mod;
		}
		return;
	};

	function<void(vector<int> &, vector<int> &)> mul = [&](vector<int> &a, vector<int> &b) {
		int n = a.size(), m = b.size(), N = 1;
		while(N <= n + m - 2) N <<= 1;
		a.resize(N); b.resize(N);
		dft(a, 0); dft(b, 0);
		for (int i = 0; i < N; ++i) a[i] = (long long) a[i] * b[i] % mod;
		dft(a, 1);
		a.resize(n + m - 1);
		return;
	};

	vector<vector<int> > f(1, vector<int>{0, 1});
	for (int i = 1; i < n - 1; ++i) {
		f.push_back(vector<int>{i, 1});
		while(f.size() >= 2 && f.back().size() >= f[f.size() - 2].size()) {
			mul(f[f.size() - 2], f.back());
			f.pop_back();
		}
	}
	while(f.size() >= 2) {
		mul(f[f.size() - 2], f.back());
		f.pop_back();
	}
	if(A + B - 2 >= f[0].size()) { cout << 0 << endl; return 0; }
	int ans = f[0][A + B - 2];
	for (int i = 1; i <= A - 1; ++i) ans = (long long) ans * pow_mod(i, mod - 2) % mod * (A + B - i - 1) % mod;
	cout << (ans + mod) % mod << endl;
	return 0;
}
posted @ 2018-10-13 11:39  King_George  阅读(200)  评论(0编辑  收藏  举报