LY1117 [ 20230228 CQYC模拟赛 T4 ] 鸥

题意

你有一个长为 \(n\) 的序列 \(a_n, a_i = i^k\)

从这个序列里抽出 \(m\) 个互不相同的数,求出这些数中最大值的期望。

请输出答案对 \(998244353\) 取模的结果。

\(8MB\)

Sol

首先将 \(n,m\) 自减一。

很显然,我们需要求出 \(\dbinom{n+1}{m+1}^{-1}\displaystyle \sum_{i=m}^n \dbinom{i}{m}(i+1)^k\)

考虑将普通幂转成上升幂,有

\[\dbinom{n+1}{m+1}^{-1}\sum_{i=m}^n \dbinom{i}{m}\sum_{j=0}^k (-1)^{k-j}\begin{Bmatrix} k\\j\end{Bmatrix}(i+1)^{\bar j} \]

不妨将组合数展开,后面的式子将上升幂转成下降幂

\[\dfrac{(m+1)!(n-m)!}{(n+1)!}\dfrac{1}{m!}\sum_{i=m}^n i^{\underline{m}}\sum_{j=0}^k (-1)^{k-j}\begin{Bmatrix} k\\j \end{Bmatrix} (i+j)^{\underline{j}} \]

合并下降幂

\[\dfrac{(m+1)(n-m)!}{(n+1)!}\sum_{i=m}^n \sum_{j=0}^k (-1)^{k-j}\begin{Bmatrix}k\\ j\end{Bmatrix}(i+j)^{\underline{j+m}} \]

接下来我们换 \(i,j\) 的求和顺序

\[\dfrac{(m+1)(n-m)!}{(n+1)!}\sum_{j=0}^k (-1)^{k-j}\begin{Bmatrix}k\\ j\end{Bmatrix}\sum_{i=m}^n (i+j)^{\underline{j+m}} \]

考虑快速求出后面的式子 \(\displaystyle \sum_{i=m}^n (i+j)^{\underline{j+m}}\),将上下标集体加 \(j\) 得到 $\displaystyle \sum_{i=m+j}{n+j}i{\underline{j+m}} $,根据有限积分,我们可以求出该式等于 \(\dfrac{(n+j+1)^{\underline{j+m+1}}-(m+j)^{\underline{j+m+1}}}{j+m+1}\)

容易发现 \((m+j)^{\underline{j+m+1}}=0\),我们要求的式子也就是

\[\dfrac{(m+1)(n-m)!}{(n+1)!}\sum_{j=0}^{k}(-1)^{k-j}\begin{Bmatrix}k\\ j\end{Bmatrix} \dfrac{(n+j+1)^{\underline{j+m+1}}}{j+m+1} \]

这里我们可以 \(O(Tk)\) 求出来了,但是我们的空间限制只有 8MiB,所以我们还要进一步化简。

\[\dfrac{(m+1)(n-m)!}{(n+1)!}\sum_{j=0}^k (-1)^{k-j}\begin{Bmatrix}k\\ j\end{Bmatrix} \dfrac{(n+j+1)!(j+m)!}{(n-m)!(j+m+1)!} \]

那么有

\[\dfrac{(m+1)}{(n+1)!}\sum_{j=0}^k (-1)^{k-j} \begin{Bmatrix}k\\j \end{Bmatrix}\dfrac{(n+j+1)!(j+m)!}{(j+m+1)!} \]

对于每个询问,我们按 \(k\) 排序,滚动斯特林数并 \(O(k)\) 求出一行。

至于阶乘,我们有两种算法:

  1. 对阶乘及其逆元进行分块,每 \(\sqrt n\) 个位置记录这个位置的阶乘和逆元,接下来阶乘顺推,逆元逆推即可,时间复杂度 \(O(T\sqrt n+Tk+k^2+n)\)

  2. 对于每个询问,我们需要用到的阶乘及其逆元只有 \(n!,m!,\dfrac{1}{(m+k+1)!}\),然后我们处理这些阶乘及其逆元,时间复杂度 \(O(Tk+k^2+n)\)

参考链接:1 2 3

Code


#include <iostream>
#include <algorithm>
#include <cstdio>
#include <array>
#include <vector>
#define int long long
#define tupl tuple <int, int, int>
using namespace std;
#ifdef ONLINE_JUDGE

#define getchar() (p1 == p2 && (p2 = (p1 = buf) + fread(buf, 1, 1 << 21, stdin), p1 == p2) ? EOF : *p1++)
char buf[1 << 23], *p1 = buf, *p2 = buf, ubuf[1 << 23], *u = ubuf;

#endif
int read() {
	int p = 0, flg = 1;
	char c = getchar();
	while (c < '0' || c > '9') {
		if (c == '-') flg = -1;
		c = getchar();
	}
	while (c >= '0' && c <= '9') {
		p = p * 10 + c - '0';
		c = getchar();
	}
	return p * flg;
}
void write(int x) {
	if (x < 0) {
		x = -x;
		putchar('-');
	}
	if (x > 9) {
		write(x / 10);
	}
	putchar(x % 10 + '0');
}
const int N = 5e3 + 5, bsi = 4200, mod = 998244353;

int pow_(int x, int k, int p) {
	int ans = 1;
	while (k) {
		if (k & 1) ans = ans * x % p;
		x = x * x % p;
		k >>= 1;
	}
	return ans;
}

array <int, N> fac, inv;

void init(int n) {
	int tp = 1, cnt = 0;
	fac[0] = 1;
	for (int i = 1; i <= n; i++) {
		tp = tp * i % mod;
		if (!(i % bsi)) cnt++, fac[cnt] = tp;
	}
	tp = pow_(tp, mod - 2, mod);
	inv[cnt] = tp;
	for (int i = n - 1; i; i--) {
		tp = tp * (i + 1) % mod;
		if (!(i % bsi)) cnt--, inv[cnt] = tp;
	}
}

int getfac(int n) {
	int tp = fac[n / bsi];
	for (int i = (n / bsi) * bsi + 1; i <= n; i++)
		tp = tp * i % mod;
	return tp;
}

int getinv(int n) {
	int tp = inv[n / bsi + 1];
	for (int i = (n / bsi + 1) * bsi - 1; i >= n; i--)
		tp = tp * (i + 1) % mod;
	return tp;
}

array <vector <tupl>, N> qrl;
array <int, N> s;

void Mod(int &x) {
	if (x >= mod) x -= mod;
	if (x < 0) x += mod;
}

array <int, N> ans, tp;

int C(int n, int m) {
	if (n < m) return 0;
	return getfac(n) * getinv(m) % mod * getinv(n - m) % mod;
}

signed main() {
	freopen("gull.in", "r", stdin);
	freopen("gull.out", "w", stdout);
	init(2.1e7);
	int m = read();
	for (int i = 1; i <= m; i++) {
		int n = read(), m = read(), k = read();
		qrl[k].push_back(make_tuple(n - 1, m - 1, i));
	}
	s[0] = 1;
	for (auto x : qrl[0])
		ans[get <2>(x)] = 1;
	for (int i = 1; i <= 5000; i++) {
		for (int j = i; j >= 1; j--)
			s[j] = s[j - 1] + j * s[j] % mod, Mod(s[j]);
		if (i == 1)
			s[0] = 0;
		if (!qrl[i].size()) continue;
		for (auto [n, m, id] : qrl[i]) {
			if (n == m) {
				::ans[id] = pow_(n + 1, i, mod);
				continue;
			}
			int ans = 0;
			int tp1 = getinv(n - m), tp2 = (i & 1 ? -1 : 1),
				tp3 = getfac(n + 1), tp4 = getfac(m);
			tp[i] = getinv(m + i + 1);
			for (int j = i - 1; ~j; j--)
				tp[j] = tp[j + 1] * (m + j + 2) % mod;
			for (int j = 0; j <= i; j++) {
				ans += tp2 * s[j] % mod * tp3 % mod * tp4 % mod * tp1 % mod * tp[j] % mod;
				tp3 = tp3 * (j + n + 2) % mod;
				tp4 = tp4 * (j + m + 1) % mod;
				Mod(ans), tp2 *= -1;
			}
			ans = ans * getinv(m) % mod * pow_(C(n + 1, m + 1), mod - 2, mod) % mod;
			::ans[id] = ans;
		}
	}
	for (int i = 1; i <= m; i++)
		write(ans[i]), puts("");
	return 0;
}
posted @ 2024-01-05 09:48  cxqghzj  阅读(12)  评论(0编辑  收藏  举报