P5299 [PKUWC2018] Slay the Spire (dp/组合计数)
P5299 [PKUWC2018] Slay the Spire
dp/组合计数
先考虑选出 \(m\) 张牌之后,怎么出牌最优。首先显然的,若选出 \(k\) 张牌,\(x\) 张强化牌一定是前 \(x\) 大的 \(a_i\),\(y\) 张攻击牌一定是前 \(y\) 大的 \(b_i\),并且肯定先用完强化牌再攻击,伤害为 \(\prod_{i=1}^xa_i\sum_{i=1}^yb_i\)(此时排序了)。如果 \(x+1\) 张强化牌,那么伤害就是 \(\prod_{i=1}^{x+1}a_i\sum_{i=1}^{y-1}b_i\)。比较哪个策略更优,作差。整理得 \((a_{x+1}-1)\sum_{i=1}^{y-1}b_i\ge b_{y}\),不等式在 \(y>1\) 时恒成立。所以策略就是能出强化牌就出,至少留下一张攻击牌。
那么就可做了。考虑计算所有方案的伤害总和。分成两种情况讨论,选出的 \(m\) 张牌中强化牌小于 \(k-1\),那么出完强化牌就出攻击牌;反之,只出前 \(k-1\) 张强化,一张攻击。可以发现这些是可以预处理的。设 \(f_{i,j,0/1}\) 表示前 \(i\) 张强化牌选 \(j\) 张的乘积和(必选第 \(i\) 张/不必选第 \(i\) 张),\(g_{i,j,0/1}\) 表示前 \(i\) 张攻击牌选 \(j\) 张的和的和(必选第 \(i\) 张/不必选第 \(i\) 张)。转移易得。
第一种情况。考虑枚举强化牌数 \(i\) 和最后一张攻击牌位置 \(j\),贡献是 \(f_{n,i,1}\times g_{j,k-i,0}\times C(n-j,m-k)\)。
第二种情况。考虑枚举最后一张强化牌位置 \(i\) 和唯一一张攻击牌位置 \(j\),贡献是 \(f_{i,k-1,0}\times b_j\times C(2\times n-i-j,m-k)\)。
复杂度 \(O(n^2)\)。
#include <bits/stdc++.h>
#define pii std::pair<int, int>
#define fi first
#define se second
#define pb push_back
using i64 = long long;
using ull = unsigned long long;
const i64 iinf = 0x3f3f3f3f, linf = 0x3f3f3f3f3f3f3f3f;
const int N = 3e3 + 10, mod = 998244353;
i64 n, m, k, ans;
i64 a[N], b[N];
i64 f[N][N][2], g[N][N][2];
i64 qpow(i64 a, i64 b, i64 m) {
i64 ret = 1;
while(b) {
if(b & 1) ret = ret * a % m;
a = a * a % m;
b >>= 1;
}
return ret;
}
struct BIN {
i64 fac[N], inv[N];
void init(int n) {
fac[0] = 1;
for(int i = 1; i <= n; i++) fac[i] = fac[i - 1] * i % mod;
inv[n] = qpow(fac[n], mod - 2, mod);
for(int i = n - 1; i >= 0; i--) inv[i] = inv[i + 1] * (i + 1) % mod;
}
i64 C(i64 n, i64 m) {
if(n < m) return 0;
return fac[n] * inv[m] % mod * inv[n - m] % mod;
}
i64 C2(i64 n, i64 m) {
if(n < m) return 0;
return fac[n] * qpow(fac[m], mod - 2, mod) % mod * qpow(fac[n - m], mod - 2, mod) % mod;
}
i64 lucas(i64 n, i64 m){
if(!m) return 1;
return C2(n % mod, m % mod) * lucas(n / mod, m / mod) % mod;
}
} comb;
void solve() {
std::cin >> n >> m >> k;
for(int i = 1; i <= n; i++) {
std::cin >> a[i];
}
for(int i = 1; i <= n; i++) {
std::cin >> b[i];
}
std::sort(a + 1, a + n + 1, std::greater<int>());
std::sort(b + 1, b + n + 1, std::greater<int>());
f[0][0][0] = f[0][0][1] = 1;
for(int i = 1; i <= n; i++) {
f[i][0][1] = 1;
for(int j = 1; j <= i; j++) {
f[i][j][0] = a[i] * f[i - 1][j - 1][1] % mod;
f[i][j][1] = (f[i][j][0] + f[i - 1][j][1]) % mod;
g[i][j][0] = (b[i] * comb.C(i - 1, j - 1) % mod + g[i - 1][j - 1][1]) % mod;
g[i][j][1] = (g[i][j][0] + g[i - 1][j][1]) % mod;
}
}
for(int i = 0; i < k - 1; i++) {
for(int j = 1; j <= n; j++) {
ans = (ans + f[n][i][1] * g[j][k - i][0] % mod * comb.C(n - j, m - k) % mod) % mod;
}
}
for(int i = 0; i <= n; i++) {
for(int j = 1; j <= n; j++) {
ans = (ans + f[i][k - 1][0] * b[j] % mod * comb.C(2 * n - i - j, m - k) % mod) % mod;
}
}
for(int i = 1; i <= n; i++) for(int j = 1; j <= n; j++) f[i][j][0] = f[i][j][1] = g[i][j][0] = g[i][j][1] = 0;
std::cout << ans << "\n";
ans = 0;
}
int main() {
std::ios::sync_with_stdio(false);
std::cin.tie(nullptr);
comb.init(N - 10);
int t;
std::cin >> t;
while(t--) solve();
return 0;
}