LY1117 [ 20230228 CQYC模拟赛 T4 ] 鸥
题意
你有一个长为 \(n\) 的序列 \(a_n, a_i = i^k\)
从这个序列里抽出 \(m\) 个互不相同的数,求出这些数中最大值的期望。
请输出答案对 \(998244353\) 取模的结果。
\(8MB\)。
Sol
首先将 \(n,m\) 自减一。
很显然,我们需要求出 \(\dbinom{n+1}{m+1}^{-1}\displaystyle \sum_{i=m}^n \dbinom{i}{m}(i+1)^k\)。
考虑将普通幂转成上升幂,有
不妨将组合数展开,后面的式子将上升幂转成下降幂
合并下降幂
接下来我们换 \(i,j\) 的求和顺序
考虑快速求出后面的式子 \(\displaystyle \sum_{i=m}^n (i+j)^{\underline{j+m}}\),将上下标集体加 \(j\) 得到 $\displaystyle \sum_{i=m+j}{n+j}i{\underline{j+m}} $,根据有限积分,我们可以求出该式等于 \(\dfrac{(n+j+1)^{\underline{j+m+1}}-(m+j)^{\underline{j+m+1}}}{j+m+1}\)
容易发现 \((m+j)^{\underline{j+m+1}}=0\),我们要求的式子也就是
这里我们可以 \(O(Tk)\) 求出来了,但是我们的空间限制只有 8MiB,所以我们还要进一步化简。
那么有
对于每个询问,我们按 \(k\) 排序,滚动斯特林数并 \(O(k)\) 求出一行。
至于阶乘,我们有两种算法:
-
对阶乘及其逆元进行分块,每 \(\sqrt n\) 个位置记录这个位置的阶乘和逆元,接下来阶乘顺推,逆元逆推即可,时间复杂度 \(O(T\sqrt n+Tk+k^2+n)\)。
-
对于每个询问,我们需要用到的阶乘及其逆元只有 \(n!,m!,\dfrac{1}{(m+k+1)!}\),然后我们处理这些阶乘及其逆元,时间复杂度 \(O(Tk+k^2+n)\)。
Code
#include <iostream>
#include <algorithm>
#include <cstdio>
#include <array>
#include <vector>
#define int long long
#define tupl tuple <int, int, int>
using namespace std;
#ifdef ONLINE_JUDGE
#define getchar() (p1 == p2 && (p2 = (p1 = buf) + fread(buf, 1, 1 << 21, stdin), p1 == p2) ? EOF : *p1++)
char buf[1 << 23], *p1 = buf, *p2 = buf, ubuf[1 << 23], *u = ubuf;
#endif
int read() {
int p = 0, flg = 1;
char c = getchar();
while (c < '0' || c > '9') {
if (c == '-') flg = -1;
c = getchar();
}
while (c >= '0' && c <= '9') {
p = p * 10 + c - '0';
c = getchar();
}
return p * flg;
}
void write(int x) {
if (x < 0) {
x = -x;
putchar('-');
}
if (x > 9) {
write(x / 10);
}
putchar(x % 10 + '0');
}
const int N = 5e3 + 5, bsi = 4200, mod = 998244353;
int pow_(int x, int k, int p) {
int ans = 1;
while (k) {
if (k & 1) ans = ans * x % p;
x = x * x % p;
k >>= 1;
}
return ans;
}
array <int, N> fac, inv;
void init(int n) {
int tp = 1, cnt = 0;
fac[0] = 1;
for (int i = 1; i <= n; i++) {
tp = tp * i % mod;
if (!(i % bsi)) cnt++, fac[cnt] = tp;
}
tp = pow_(tp, mod - 2, mod);
inv[cnt] = tp;
for (int i = n - 1; i; i--) {
tp = tp * (i + 1) % mod;
if (!(i % bsi)) cnt--, inv[cnt] = tp;
}
}
int getfac(int n) {
int tp = fac[n / bsi];
for (int i = (n / bsi) * bsi + 1; i <= n; i++)
tp = tp * i % mod;
return tp;
}
int getinv(int n) {
int tp = inv[n / bsi + 1];
for (int i = (n / bsi + 1) * bsi - 1; i >= n; i--)
tp = tp * (i + 1) % mod;
return tp;
}
array <vector <tupl>, N> qrl;
array <int, N> s;
void Mod(int &x) {
if (x >= mod) x -= mod;
if (x < 0) x += mod;
}
array <int, N> ans, tp;
int C(int n, int m) {
if (n < m) return 0;
return getfac(n) * getinv(m) % mod * getinv(n - m) % mod;
}
signed main() {
freopen("gull.in", "r", stdin);
freopen("gull.out", "w", stdout);
init(2.1e7);
int m = read();
for (int i = 1; i <= m; i++) {
int n = read(), m = read(), k = read();
qrl[k].push_back(make_tuple(n - 1, m - 1, i));
}
s[0] = 1;
for (auto x : qrl[0])
ans[get <2>(x)] = 1;
for (int i = 1; i <= 5000; i++) {
for (int j = i; j >= 1; j--)
s[j] = s[j - 1] + j * s[j] % mod, Mod(s[j]);
if (i == 1)
s[0] = 0;
if (!qrl[i].size()) continue;
for (auto [n, m, id] : qrl[i]) {
if (n == m) {
::ans[id] = pow_(n + 1, i, mod);
continue;
}
int ans = 0;
int tp1 = getinv(n - m), tp2 = (i & 1 ? -1 : 1),
tp3 = getfac(n + 1), tp4 = getfac(m);
tp[i] = getinv(m + i + 1);
for (int j = i - 1; ~j; j--)
tp[j] = tp[j + 1] * (m + j + 2) % mod;
for (int j = 0; j <= i; j++) {
ans += tp2 * s[j] % mod * tp3 % mod * tp4 % mod * tp1 % mod * tp[j] % mod;
tp3 = tp3 * (j + n + 2) % mod;
tp4 = tp4 * (j + m + 1) % mod;
Mod(ans), tp2 *= -1;
}
ans = ans * getinv(m) % mod * pow_(C(n + 1, m + 1), mod - 2, mod) % mod;
::ans[id] = ans;
}
}
for (int i = 1; i <= m; i++)
write(ans[i]), puts("");
return 0;
}