CF1528F AmShZ Farm

太神秘了。

考虑这个题目的限制有一个经典的转化:第 i 个人要坐到位置 ai,如果 ai 有人就继续查看 ai+1n 个人之后前 n 个位置坐满就合法。

先考虑如何对合法的序列计数。这个模型的套路是在最后加一个位置,把序列首位相接形成一个长度为 n+1 的环,每次如果 ai 有人就往环上下一个走。这样最后空出来的位置是 n+1 才合法,而显然每个位置被空出来的概率是相等的,所以合法的序列数为 (n+1)nn+1

再考虑这题,我们发现和上面的计数一样:可以统计出所有序列的答案再除以 n+1。我们发现,共 n+1 种数的贡献是相同的,我们只算一种数的贡献就是答案。一种数的贡献容易列出式子:

i=1n(ni)iknni

然后这个 ik 你一看就会想到第二类斯特林数。后面的过程不难,转下降幂推一推就好了。

#include <bits/stdc++.h>

using namespace std;

const int N = 4e5 + 5, mod = 998244353, G = 3;

int rev[N];

inline int power(int a, int b) {
	int k = b, y = a, t = 1;
	while (k) {
		if (k & 1) t = (1ll * y * t) % mod;
		y = (1ll * y * y) % mod; k >>= 1;
	} return t;
}

const int Gi = power(G, mod - 2);

struct poly {
	int len;
	vector<int> x;
	
	inline void NTT(int flag) {
		for (int i = 0; i <= len; ++i)
			if (i < rev[i]) swap(x[i], x[rev[i]]);
		for (int mid = 1; mid < len; mid <<= 1) {
			const int Wn = power(flag == 1 ? G : Gi, (mod - 1) / (mid << 1));
			for (int l = 0; l < len; l += mid << 1) {
				int w = 1;
				for (int t = 0; t < mid; ++t, w = (1ll * w * Wn) % mod) {
					int a = x[l + t], b = (1ll * w * x[l + mid + t]) % mod;
					x[l + t] = (a + b) % mod;
					x[l + mid + t] = ((a - b) % mod + mod) % mod;
				}
			}
		}
	}
};

inline poly mul(poly a, poly b) {
	poly c; int len = a.len + b.len;
	int tmp = 1, T = 0;
	while (tmp <= len) tmp <<= 1, ++T;
	for (int i = 1; i <= tmp; ++i)
		rev[i] = (rev[i >> 1] >> 1) | ((i & 1) << T - 1);
	c.x.resize(N); c.len = tmp;
	a.len = tmp; b.len = tmp;
	const int inv = power(tmp, mod - 2);
	a.NTT(1); b.NTT(1);
	for (int i = 0; i <= a.len; ++i) c.x[i] = (1ll * a.x[i] * b.x[i]) % mod;
	c.NTT(-1);
	for (int i = 0; i <= a.len; ++i) c.x[i] = (1ll * c.x[i] * inv) % mod;
	c.len = len;
	return c;
}

poly a, b, c;

int fac[N], ifac[N], n, k;

inline void init() {
	a.x.resize(N); b.x.resize(N); a.len = b.len = k;
	fac[0] = 1; for (int i = 1; i <= k; ++i) fac[i] = (1ll * fac[i - 1] * i) % mod;
	ifac[k] = power(fac[k], mod - 2);
	for (int i = k - 1; ~i; --i) ifac[i] = (1ll * (i + 1) * ifac[i + 1]) % mod;
	for (int i = 0; i <= k; ++i) {
		a.x[i] = (1ll * power(i, k) * ifac[i]) % mod;
		b.x[i] = ifac[i]; if (i & 1) b.x[i] = -b.x[i] + mod;
		if (b.x[i] >= mod) b.x[i] -= mod;
	} c = mul(a, b);
}

int main() {
	scanf("%d%d", &n, &k); init();
	int res = 0;
	for (int i = 0, C = 1; i <= k && i <= n; C = 1ll * C * (n - i) % mod, ++i) {
		int del = 1ll * C * c.x[i] % mod;
		del = 1ll * del * power(n + 1, n - i) % mod;
		res += del; if (res >= mod) res -= mod;
	} printf("%d\n", res);
	return 0;
}
posted @   Smallbasic  阅读(19)  评论(0编辑  收藏  举报
相关博文:
阅读排行:
· 无需6万激活码!GitHub神秘组织3小时极速复刻Manus,手把手教你使用OpenManus搭建本
· Manus爆火,是硬核还是营销?
· 终于写完轮子一部分:tcp代理 了,记录一下
· 别再用vector<bool>了!Google高级工程师:这可能是STL最大的设计失误
· 单元测试从入门到精通
点击右上角即可分享
微信分享提示