ARC138E

给定正整数 \(n,k\),对于一个长为 \(n\) 的序列 \(\left\{a_i\right\}\),我们称其是好的当且仅当:

  1. \(\forall i,1\le i\le n\),均满足 \(0\le a_i\le i\)
  2. \(\forall v=1,2,\cdots,n\),至多有一个下标 \(i\) 满足 \(a_i=v\)

求所有好序列中长为 \(k\) 的严格递减子序列的出现次数之和,答案对 \(10^9+7\) 取模。

第一步是巧妙的转化:对于 \(a_i\gt0\),连边 \(i\to a_i-1\)
对于每个好的序列 \(a\),连边后形成若干条链
假设选出的 \(k\) 个点为 \(b_1,b_2,\cdots,b_k\),并且 \(b_1\lt b_2\lt \cdots \lt b_k\),对应的 \(a_{b_1}-1,\cdots,a_{b_k}-1\) 为了方便记作 \(c_1,c_2,\cdots,c_k\),则有:
\(c_k\lt c_{k-1}\lt \cdots \lt c_1\lt b_1\lt b_2\lt\cdots\lt b_k\)
这可以由 \(b_1\gt c_1\)\(c\) 递减,\(b\) 递增推导出来。
也就是说,对于确定的 \(a\),在 \([0,n]\) 中选出 \(2k\) 个点,它们对应的子序列有且仅有一种
注意到这 \(k\) 条边一定恰好涉及到 \(k\) 条链,先删去这 \(k\) 条边,将其中不超过 \(c_1\) 的点集记为 \(A\),不小于 \(b_1\) 的点集记为 \(B\)
那么我们要计算分别将 \(A,B\) 拆成 \(k\) 条链的方案数。注意,一旦两边拆出 \(k\) 条链的方案确定了,连 \(k\) 条边的方案也就唯一确定了。
考虑枚举前 \(k\) 个点以及所在链上的其他点的个数 \(i\),后 \(k\) 个点以及所在链上的其他点的个数 \(j\),则方案数为:
\(\dbinom{n+1}{i+j}\)\(\begin{Bmatrix}i\\ k\end{Bmatrix}\begin{Bmatrix}j\\ k\end{Bmatrix}\sum\limits_{x=1}^{n+1-i-j}\begin{Bmatrix}n+1-i-j\\ x\end{Bmatrix}\)
\(i\) 个点分成 \(j\) 条链的方案数等价于第二类斯特林数,因为链上总是编号大的指向编号小的。
注意到 \(\sum\limits_{x=1}^{n+1-i-j}\begin{Bmatrix}n+1-i-j\\ x\end{Bmatrix}\) 其实就是贝尔数 \(B_{n+1-i-j}\)
上面这些组合数,斯特林数,贝尔数都可以 \(\mathcal O(n^2)\) 预处理出来。

Code:

#include <bits/stdc++.h>
using namespace std;
typedef long long ll;
const int N = 5005, mod = 1e9 + 7;
int n, k;
int fac[N], inv[N];
int S[N][N], B[N];

int qpow(int x, int y) {
	int res = 1;
	while (y) {
		if (y & 1) res = 1ll * res * x % mod;
		x = 1ll * x * x % mod;
		y >>= 1;
	}
	return res;
}

void init(int n) {
	fac[0] = 1;
	for (int i = 1; i <= n; ++i) fac[i] = 1ll * fac[i - 1] * i % mod;
	inv[n] = qpow(fac[n], mod - 2);
	for (int i = n - 1; ~i; --i) inv[i] = 1ll * inv[i + 1] * (i + 1) % mod;
	S[0][0] = 1;
	for (int i = 1; i <= n; ++i) {
		S[i][1] = S[i][i] = 1;
		for (int j = 2; j < i; ++j)
			S[i][j] = (S[i - 1][j - 1] + 1ll * S[i - 1][j] * j % mod) % mod;
	}
	for (int i = 0; i <= n; ++i) {
		for (int j = 0; j <= i; ++j) B[i] = (B[i] + S[i][j]) % mod;
	}
}

int C(int n, int m) {
	if (n < 0 || m < 0 || n < m) return 0;
	return 1ll * fac[n] * inv[n - m] % mod * inv[m] % mod;
}

int main() {
	scanf("%d%d", &n, &k);
	init(n + 1);
	int ans = 0;
	for (int i = k; i <= n + 1; ++i)
		for (int j = k; j <= n + 1 - i; ++j) {
			int tmp = C(n + 1, i + j);
			tmp = 1ll * tmp * S[i][k] % mod;
			tmp = 1ll * tmp * S[j][k] % mod;
			tmp = 1ll * tmp * B[n + 1 - (i + j)] % mod;
			ans = (ans + tmp) % mod;
		}
	printf("%d", ans);
	return 0;
}
posted @ 2022-10-04 09:00  Kobe303  阅读(20)  评论(0编辑  收藏  举报