CodeForces 1919E Counting Prefixes
考虑一个很类似的题。我们把正数和负数分开来考虑,最后用 \(0\) 连接一些连续段,形如 \(0 - \text{正} - 0 - \text{正} - 0 - \text{负}\)。
先考虑正数。设 \(f_{i, j}\) 为考虑了 \(\ge i\) 的正数,形成了 \(j\) 个连续段的方案数。设 \(i\) 的出现次数为 \(c_i\)。
那么之前的每个段两端都需要接一个 \(i\) 下来,两段之间也可以只用一个 \(i\) 连接。
特别地,如果已经考虑到了结尾位置 \(n\),右端不用接数。于是我们状态再记一个 \(f_{i, j, 0/1}\) 表示包含位置 \(n\) 的段是否出现。
那么对于 \(f_{i + 1, j, 0}\) 的转移,新的段数 \(k = c_i - j\) 可以直接被计算出来。转移系数是 \(c_i\) 个数分配给 \(j + 1\) 个空的插板。我们有:
对于 \(f_{i + 1, j, 1}\) 的转移,新的段数为 \(k = c_i - j + 1\)。有转移:
同样地考虑负数,设 \(g_{i, j}\) 为考虑了 \(\le -i\) 的负数,形成了 \(j\) 个连续段的方案数。转移类似。
计算答案,只用枚举正数的段数 \(i\),如果 \(0\) 不为 \(n\) 的前缀和,那么负数的段数为 \(c_0 - i\),否则为 \(c_0 - i - 1\)。讨论一下 \(n\) 的前缀和是给 \(0\),正数还是负数,再乘上给段选择位置的系数。所以:
直接计算 \(f, g\) 复杂度为 \(O(n^2)\)。但是注意到 \(f, g\) 初值只有 \(O(1)\) 个位置有值,每次转移一个位置会转移到固定的另一个位置,最后一维只会进行 \(0 \to 1\) 的转移不会进行 \(1 \to 0\) 的转移。所以 dp 数组每行有值的位置数是 \(\color{red}{O(1)}\) 的。如果我们用 unordered_map 把有值的位置存下来,复杂度将是 \(\color{red}{O(n)}\)。
code
// Problem: E. Counting Prefixes
// Contest: Codeforces - Hello 2024
// URL: https://codeforces.com/contest/1919/problem/E
// Memory Limit: 256 MB
// Time Limit: 1000 ms
//
// Powered by CP Editor (https://cpeditor.org)
#include <bits/stdc++.h>
#define pb emplace_back
#define fst first
#define scd second
#define mkp make_pair
#define mems(a, x) memset((a), (x), sizeof(a))
using namespace std;
typedef long long ll;
typedef double db;
typedef unsigned long long ull;
typedef long double ldb;
typedef pair<ll, ll> pii;
const int maxn = 5050;
const ll mod = 998244353;
inline ll qpow(ll b, ll p) {
ll res = 1;
while (p) {
if (p & 1) {
res = res * b % mod;
}
b = b * b % mod;
p >>= 1;
}
return res;
}
ll n, a[maxn], fac[maxn], ifac[maxn], b[maxn], c[maxn];
unordered_map<ll, ll> f[2][2], g[2][2];
inline ll C(ll n, ll m) {
if (n < m || n < 0 || m < 0) {
return 0;
} else {
return fac[n] * ifac[m] % mod * ifac[n - m] % mod;
}
}
void solve() {
scanf("%lld", &n);
for (int i = 0; i <= n; ++i) {
b[i] = c[i] = 0;
}
ll L = 0, R = 0;
++b[0];
for (int i = 1; i <= n; ++i) {
scanf("%lld", &a[i]);
L = min(L, a[i]);
R = max(R, a[i]);
if (a[i] >= 0) {
++b[a[i]];
} else {
++c[-a[i]];
}
}
a[0] = 0;
sort(a, a + n + 1);
fac[0] = 1;
for (int i = 1; i <= n; ++i) {
fac[i] = fac[i - 1] * i % mod;
}
ifac[n] = qpow(fac[n], mod - 2);
for (int i = n - 1; ~i; --i) {
ifac[i] = ifac[i + 1] * (i + 1) % mod;
}
for (int i = 0; i < 2; ++i) {
for (int j = 0; j < 2; ++j) {
f[i][j].clear();
g[i][j].clear();
}
}
if (R) {
f[1][0][b[R]] = f[1][1][b[R]] = 1;
int o = 0;
for (int i = R - 1; i; --i, o ^= 1) {
f[o][0].clear();
f[o][1].clear();
for (pii p : f[o ^ 1][0]) {
ll j = p.fst, v = p.scd;
if (b[i] >= j + 1) {
int nj = b[i] - j;
f[o][0][nj] = (f[o][0][nj] + v * C(b[i] - 1, j)) % mod;
f[o][1][nj] = (f[o][1][nj] + v * C(b[i] - 1, j)) % mod;
}
}
for (pii p : f[o ^ 1][1]) {
ll j = p.fst, v = p.scd;
if (b[i] >= j) {
int nj = b[i] - j + 1;
f[o][1][nj] = (f[o][1][nj] + v * C(b[i] - 1, j - 1)) % mod;
}
}
}
} else {
f[0][0][0] = 1;
}
if (L) {
L = -L;
g[1][0][c[L]] = g[1][1][c[L]] = 1;
int o = 0;
for (int i = L - 1; i; --i, o ^= 1) {
g[o][0].clear();
g[o][1].clear();
for (pii p : g[o ^ 1][0]) {
ll j = p.fst, v = p.scd;
if (c[i] >= j + 1) {
int nj = c[i] - j;
g[o][0][nj] = (g[o][0][nj] + v * C(c[i] - 1, j)) % mod;
g[o][1][nj] = (g[o][1][nj] + v * C(c[i] - 1, j)) % mod;
}
}
for (pii p : g[o ^ 1][1]) {
ll j = p.fst, v = p.scd;
if (c[i] >= j) {
int nj = c[i] - j + 1;
g[o][1][nj] = (g[o][1][nj] + v * C(c[i] - 1, j - 1)) % mod;
}
}
}
} else {
g[0][0][0] = 1;
}
ll ans = 0;
for (int i = 0; i <= b[0]; ++i) {
ans = (ans + f[R & 1][1][i] * g[L & 1][0][b[0] - i] % mod * C(b[0] - 1, i - 1)) % mod;
ans = (ans + f[R & 1][0][i] * g[L & 1][1][b[0] - i] % mod * C(b[0] - 1, i)) % mod;
if (i < b[0]) {
ans = (ans + f[R & 1][0][i] * g[L & 1][0][b[0] - i - 1] % mod * C(b[0] - 1, i)) % mod;
}
}
printf("%lld\n", ans);
}
int main() {
int T = 1;
scanf("%d", &T);
while (T--) {
solve();
}
return 0;
}