CodeForces 1943D2 Counting Is Fun (Hard Version)
被自己的赛时智障操作气笑了。谁告诉你容斥钦定了几个要记到状态里面的。。。/tuu
显然先找“好数组”的充要条件。对原数组 \(a\) 差分,设 \(b_i = a_i - a_{i - 1}\)。那么一次可以选择一对 \((i, j)\) 满足 \(i \le j - 2\),然后给 \(b_i\) 减 \(1\),给 \(b_j\) 加 \(1\)。
我们从左往右操作。注意到我们不能操作相邻的一对元素,所以若某个时刻 \(b_i > 0\) 且 \(b_{1 \sim i - 2}\) 都为 \(0\) 就不合法。这就是充要条件。
充要条件可以表述成 \(b_i + \sum\limits_{j = 1}^{i - 2} b_j \ge 0\),即 \(a_i - a_{i - 1} + a_{i - 2} \ge 0\),注意 \(a_0 = a_{n + 1} = 0\)。
对于 \(n \le 400\) 可以直接做一个 \(O(n^3)\) dp,设 \(f_{i, j, k}\) 为考虑了 \([1, i]\) 的前缀,\(a_{i - 1} = j, a_i = k\) 的方案数。因为合法的 \(a_{i - 2}\) 是一段后缀,所以预处理后缀和即可做到 \(O(1)\) 转移。这样可以通过 D1。
对于 \(n \le 3000\) 显然不能这么做了。发现 dp 状态数都成瓶颈了,换个思路,考虑容斥,钦定一些位置 \(i\) 是满足 \(a_i > a_{i - 1} + a_{i + 1}\)(显然这些位置不会相邻),容斥系数就是 \(-1\) 的位置个数次方。
对于 \(i\) 被钦定的方案,注意到我们不用枚举 \(a_i\),只要知道 \(a_{i - 1}\) 和 \(a_{i + 1}\),就能算出 \(a_i\) 的取值个数为 \(m - a_{i - 1} - a_{i + 1}\)(令题面中的 \(k\) 为 \(m\))。
所以设 \(f_{i, j}\) 为考虑了 \([1, i]\) 的前缀,\(a_i = j\),每种方案乘上容斥系数的和。那么有 \(f_{i, j} = \sum\limits_{k = 0}^m f_{i - 1, k} - (m - j - k) f_{i - 2, k}\),容易前缀和优化。
时间复杂度 \(O(n^2)\)。
code
// Problem: D2. Counting Is Fun (Hard Version)
// Contest: Codeforces - Codeforces Round 934 (Div. 1)
// URL: https://codeforces.com/contest/1943/problem/D2
// Memory Limit: 1024 MB
// Time Limit: 3000 ms
//
// Powered by CP Editor (https://cpeditor.org)
#include <bits/stdc++.h>
#define pb emplace_back
#define fst first
#define scd second
#define mkp make_pair
#define mems(a, x) memset((a), (x), sizeof(a))
using namespace std;
typedef long long ll;
typedef double db;
typedef unsigned long long ull;
typedef long double ldb;
typedef pair<ll, ll> pii;
const int maxn = 3030;
ll n, m, mod, f[maxn][maxn], g[maxn], h[maxn];
void solve() {
scanf("%lld%lld%lld", &n, &m, &mod);
f[0][0] = 1;
for (int i = 1; i <= n + 1; ++i) {
ll s = 0;
for (int j = 0; j <= m; ++j) {
s = (s + f[i - 1][j]) % mod;
}
for (int j = 0; j <= m; ++j) {
f[i][j] = s;
}
if (i >= 2) {
for (int j = 0; j <= m; ++j) {
g[j] = f[i - 2][j];
h[j] = f[i - 2][j] * j % mod;
if (j) {
g[j] = (g[j] + g[j - 1]) % mod;
h[j] = (h[j] + h[j - 1]) % mod;
}
}
for (int j = 0; j <= m; ++j) {
ll x = (g[m - j] * (m - j) - h[m - j] + mod) % mod;
f[i][j] = (f[i][j] - x + mod) % mod;
}
}
}
printf("%lld\n", f[n + 1][0]);
}
int main() {
int T = 1;
scanf("%d", &T);
while (T--) {
solve();
}
return 0;
}