LOJ #2538. 「PKUWC 2018」Slay the Spire (期望dp)
Update on 1.5
学了 zhou888 的写法,真是又短又快。
并且空间是 \(O(n)\) 的,速度十分优秀。
题意
LOJ #2538. 「PKUWC 2018」Slay the Spire
题解
首先我们考虑拿到一副牌如何打是最优的,不难发现是将强化牌从大到小能打就打,最后再从大到小打攻击牌 。
为什么呢 ?
证明(简单说明) : 如果不是这样 , 那么我们就是有强化牌没有用 , 且攻击牌超过两张 .
我们考虑把最小的那张攻击牌拿出来 , 然后放入一张强化牌 .
\(\because~w_i \ge 2\) 且 最小那张攻击牌的攻击力 \(a_{\min}\) 不大于所有攻击牌的总和 \(a_{sum}\) 的一半
\(\therefore\) 修改后造成的伤害绝对不比前面少 . 得证.
我们只要枚举上下分别用了多少张牌 , 假设 加强 用了 \(a\) 张 , 攻击 用了 \(b\) 张 . \((a + b = m)\)
那么我们只要分两种情况考虑了 :
- \(a < k:\) 那么我们加强可以全用完 , 攻击用前 \(k - a\) 大的 ;
- \(a \ge k:\) 这个加强用前 \(k - 1\) 大的 , 攻击用一张最大的 .
令 \(f_i\) 为选 \(i\) 张强化牌能得到的最优倍率之和,显然强化牌我们从大到小取是最优的。
假设当前取到第 \(j\) 张牌。
那么有如下转移:
上面那种情况是还能用强化牌,下面已经不能加新的强化牌了,所以不乘上倍率。(注意 \(f_0 = 1\) )
同样我们设 \(g_i\) 为选 \(i\) 张攻击牌能得到的最优攻击之和,此处我们需要从小到大排序。
有如下转移:
考虑这张牌我们先放进来,不难发现对于所有 \(i \le m - (k - 1)\) 也就是只能打一张的,我们只统计了这张打的贡献。
如果能打很多张,这样转移的话就能保证我们尽量取的是靠后的那些元素。
最后答案直接就是 \(\displaystyle \sum_{i = 0}^{m} f_{i} g_{m-i}\) 。
总结
需要啥就设啥,想清楚情况再 \(dp\) 。
代码
#include <bits/stdc++.h>
#define For(i, l, r) for (register int i = (l), i##end = (int)(r); i <= i##end; ++i)
#define Fordown(i, r, l) for (register int i = (r), i##end = (int)(l); i >= i##end; --i)
#define Rep(i, r) for (register int i = (0), i##end = (int)(r); i < i##end; ++i)
#define Set(a, v) memset(a, v, sizeof(a))
#define Cpy(a, b) memcpy(a, b, sizeof(a))
#define debug(x) cout << #x << ": " << (x) << endl
using namespace std;
template<typename T> inline bool chkmin(T &a, T b) { return b < a ? a = b, 1 : 0; }
template<typename T> inline bool chkmax(T &a, T b) { return b > a ? a = b, 1 : 0; }
inline int read() {
int x(0), sgn(1); char ch(getchar());
for (; !isdigit(ch); ch = getchar()) if (ch == '-') sgn = -1;
for (; isdigit(ch); ch = getchar()) x = (x * 10) + (ch ^ 48);
return x * sgn;
}
void File() {
#ifdef zjp_shadow
freopen ("2538.in", "r", stdin);
freopen ("2538.out", "w", stdout);
#endif
}
const int N = 3e3 + 1e2, Mod = 998244353;
inline int fpm(int x, int power) {
int res = 1;
for (; power; power >>= 1, x = 1ll * x * x % Mod)
if (power & 1) res = 1ll * res * x % Mod;
return res;
}
int fac[N], ifac[N];
void Math_Init(int maxn) {
fac[0] = ifac[0] = 1;
For (i, 1, maxn) fac[i] = 1ll * fac[i - 1] * i % Mod;
ifac[maxn] = fpm(fac[maxn], Mod - 2);
Fordown (i, maxn - 1, 1) ifac[i] = ifac[i + 1] * (i + 1ll) % Mod;
}
inline int Comb(int n, int m) {
if (n < 0 || m < 0 || n < m) return 0;
return 1ll * fac[n] * ifac[m] % Mod * ifac[n - m] % Mod;
}
int n, m, k, a[N], f[N], g[N];
int main () {
File();
Math_Init(3000);
for (int cases = read(); cases; -- cases) {
n = read(); m = read(); k = read();
For (i, 1, n) a[i] = read();
For (i, 1, max(n, m)) f[i] = g[i] = 0;
sort(a + 1, a + n + 1, greater<int>());
f[0] = 1;
For (i, 1, n) Fordown (j, min(i, m), 0)
if (j <= k - 1) f[j] = (f[j] + 1ll * f[j - 1] * a[i]) % Mod;
else f[j] = (f[j] + f[j - 1]) % Mod;
For (i, 1, n) a[i] = read();
sort(a + 1, a + n + 1);
For (i, 1, n) Fordown (j, min(i, m), 0)
if (j <= m - (k - 1))
g[j] = (g[j] + 1ll * Comb(i - 1, j - 1) * a[i]) % Mod;
else
g[j] = (g[j] + g[j - 1] + 1ll * Comb(i - 1, j - 1) * a[i]) % Mod;
int ans = 0;
For (i, 0, m)
ans = (ans + 1ll * f[i] * g[m - i]) % Mod;
printf ("%d\n", ans);
}
return 0;
}