Loading

P5299 [PKUWC2018] Slay the Spire (dp/组合计数)

P5299 [PKUWC2018] Slay the Spire

dp/组合计数

先考虑选出 \(m\) 张牌之后,怎么出牌最优。首先显然的,若选出 \(k\) 张牌,\(x\) 张强化牌一定是前 \(x\) 大的 \(a_i\)\(y\) 张攻击牌一定是前 \(y\) 大的 \(b_i\),并且肯定先用完强化牌再攻击,伤害为 \(\prod_{i=1}^xa_i\sum_{i=1}^yb_i\)(此时排序了)。如果 \(x+1\) 张强化牌,那么伤害就是 \(\prod_{i=1}^{x+1}a_i\sum_{i=1}^{y-1}b_i\)。比较哪个策略更优,作差。整理得 \((a_{x+1}-1)\sum_{i=1}^{y-1}b_i\ge b_{y}\),不等式在 \(y>1\) 时恒成立。所以策略就是能出强化牌就出,至少留下一张攻击牌

那么就可做了。考虑计算所有方案的伤害总和。分成两种情况讨论,选出的 \(m\) 张牌中强化牌小于 \(k-1\),那么出完强化牌就出攻击牌;反之,只出前 \(k-1\) 张强化,一张攻击。可以发现这些是可以预处理的。设 \(f_{i,j,0/1}\) 表示前 \(i\) 张强化牌选 \(j\) 张的乘积和(必选第 \(i\) 张/不必选第 \(i\) 张),\(g_{i,j,0/1}\) 表示前 \(i\) 张攻击牌选 \(j\) 张的和的和(必选第 \(i\) 张/不必选第 \(i\) 张)。转移易得。

第一种情况。考虑枚举强化牌数 \(i\) 和最后一张攻击牌位置 \(j\),贡献是 \(f_{n,i,1}\times g_{j,k-i,0}\times C(n-j,m-k)\)

第二种情况。考虑枚举最后一张强化牌位置 \(i\) 和唯一一张攻击牌位置 \(j\),贡献是 \(f_{i,k-1,0}\times b_j\times C(2\times n-i-j,m-k)\)

复杂度 \(O(n^2)\)

#include <bits/stdc++.h>
#define pii std::pair<int, int>
#define fi first
#define se second
#define pb push_back

using i64 = long long;
using ull = unsigned long long;
const i64 iinf = 0x3f3f3f3f, linf = 0x3f3f3f3f3f3f3f3f;
const int N = 3e3 + 10, mod = 998244353;
i64 n, m, k, ans;
i64 a[N], b[N];
i64 f[N][N][2], g[N][N][2];
i64 qpow(i64 a, i64 b, i64 m) {
	i64 ret = 1;
	while(b) {
		if(b & 1) ret = ret * a % m;
		a = a * a % m;
		b >>= 1;
	}
	return ret;
}
struct BIN {
	i64 fac[N], inv[N];
	void init(int n) {
		fac[0] = 1;
		for(int i = 1; i <= n; i++) fac[i] = fac[i - 1] * i % mod;
		inv[n] = qpow(fac[n], mod - 2, mod);
		for(int i = n - 1; i >= 0; i--) inv[i] = inv[i + 1] * (i + 1) % mod;
	}
	i64 C(i64 n, i64 m) {
		if(n < m) return 0;
		return fac[n] * inv[m] % mod * inv[n - m] % mod;
	}
	i64 C2(i64 n, i64 m) {
		if(n < m) return 0;	
		return fac[n] * qpow(fac[m], mod - 2, mod) % mod * qpow(fac[n - m], mod - 2, mod) % mod;
	}
	i64 lucas(i64 n, i64 m){
		if(!m) return 1;
		return C2(n % mod, m % mod) * lucas(n / mod, m / mod) % mod;
	}
} comb;
void solve() {
	std::cin >> n >> m >> k;
	for(int i = 1; i <= n; i++) {
		std::cin >> a[i];
	}
	for(int i = 1; i <= n; i++) {
		std::cin >> b[i];
	}
	std::sort(a + 1, a + n + 1, std::greater<int>());
	std::sort(b + 1, b + n + 1, std::greater<int>());
	f[0][0][0] = f[0][0][1] = 1;
	for(int i = 1; i <= n; i++) {
		f[i][0][1] = 1;
		for(int j = 1; j <= i; j++) {
			f[i][j][0] = a[i] * f[i - 1][j - 1][1] % mod;
			f[i][j][1] = (f[i][j][0] + f[i - 1][j][1]) % mod;
			g[i][j][0] = (b[i] * comb.C(i - 1, j - 1) % mod + g[i - 1][j - 1][1]) % mod;
			g[i][j][1] = (g[i][j][0] + g[i - 1][j][1]) % mod;
		}
	}
	for(int i = 0; i < k - 1; i++) {
		for(int j = 1; j <= n; j++) {
			ans = (ans + f[n][i][1] * g[j][k - i][0] % mod * comb.C(n - j, m - k) % mod) % mod;
		}
	}
	for(int i = 0; i <= n; i++) {
		for(int j = 1; j <= n; j++) {
			ans = (ans + f[i][k - 1][0] * b[j] % mod * comb.C(2 * n - i - j, m - k) % mod) % mod;
		}
	}
	for(int i = 1; i <= n; i++) for(int j = 1; j <= n; j++) f[i][j][0] = f[i][j][1] = g[i][j][0] = g[i][j][1] = 0;
	std::cout << ans << "\n";
	ans = 0;
}

int main() {
    std::ios::sync_with_stdio(false);
    std::cin.tie(nullptr);
    comb.init(N - 10);

    int t;
    std::cin >> t;

	while(t--) solve();

	return 0;
}
posted @ 2024-05-23 13:45  Fire_Raku  阅读(5)  评论(0编辑  收藏  举报