LOJ #2538. 「PKUWC 2018」Slay the Spire (期望dp)

Update on 1.5

学了 zhou888 的写法,真是又短又快。

并且空间是 \(O(n)\) 的,速度十分优秀。

题意

LOJ #2538. 「PKUWC 2018」Slay the Spire

题解

首先我们考虑拿到一副牌如何打是最优的,不难发现是将强化牌从大到小能打就打,最后再从大到小打攻击牌 。

为什么呢 ?

证明(简单说明) : 如果不是这样 , 那么我们就是有强化牌没有用 , 且攻击牌超过两张 .

我们考虑把最小的那张攻击牌拿出来 , 然后放入一张强化牌 .

\(\because~w_i \ge 2\) 且 最小那张攻击牌的攻击力 \(a_{\min}\) 不大于所有攻击牌的总和 \(a_{sum}\) 的一半

\(\therefore\) 修改后造成的伤害绝对不比前面少 . 得证.

我们只要枚举上下分别用了多少张牌 , 假设 加强 用了 \(a\) 张 , 攻击 用了 \(b\) 张 . \((a + b = m)\)

那么我们只要分两种情况考虑了 :

  1. \(a < k:\) 那么我们加强可以全用完 , 攻击用前 \(k - a\) 大的 ;
  2. \(a \ge k:\) 这个加强用前 \(k - 1\) 大的 , 攻击用一张最大的 .

\(f_i\) 为选 \(i\) 张强化牌能得到的最优倍率之和,显然强化牌我们从大到小取是最优的。

假设当前取到第 \(j\) 张牌。

那么有如下转移:

\[f_i = \begin{cases} (f_i + f_{i - 1}) \times a[j] &i < k\\ f_i + f_{i - 1} &i \ge k \end{cases} \]

上面那种情况是还能用强化牌,下面已经不能加新的强化牌了,所以不乘上倍率。(注意 \(f_0 = 1\)

同样我们设 \(g_i\) 为选 \(i\) 张攻击牌能得到的最优攻击之和,此处我们需要从小到大排序。

有如下转移:

\[g_i = g_i + \displaystyle {j - 1 \choose i - 1} \times a[j] + \begin{cases} 0 &\le m - (k - 1)\\ g_{i - 1} & >m - (k - 1) \end{cases} \]

考虑这张牌我们先放进来,不难发现对于所有 \(i \le m - (k - 1)\) 也就是只能打一张的,我们只统计了这张打的贡献。

如果能打很多张,这样转移的话就能保证我们尽量取的是靠后的那些元素。

最后答案直接就是 \(\displaystyle \sum_{i = 0}^{m} f_{i} g_{m-i}\)

总结

需要啥就设啥,想清楚情况再 \(dp\)

代码

#include <bits/stdc++.h>

#define For(i, l, r) for (register int i = (l), i##end = (int)(r); i <= i##end; ++i)
#define Fordown(i, r, l) for (register int i = (r), i##end = (int)(l); i >= i##end; --i)
#define Rep(i, r) for (register int i = (0), i##end = (int)(r); i < i##end; ++i)
#define Set(a, v) memset(a, v, sizeof(a))
#define Cpy(a, b) memcpy(a, b, sizeof(a))
#define debug(x) cout << #x << ": " << (x) << endl

using namespace std;

template<typename T> inline bool chkmin(T &a, T b) { return b < a ? a = b, 1 : 0; }
template<typename T> inline bool chkmax(T &a, T b) { return b > a ? a = b, 1 : 0; }

inline int read() {
    int x(0), sgn(1); char ch(getchar());
    for (; !isdigit(ch); ch = getchar()) if (ch == '-') sgn = -1;
    for (; isdigit(ch); ch = getchar()) x = (x * 10) + (ch ^ 48);
    return x * sgn;
}

void File() {
#ifdef zjp_shadow
	freopen ("2538.in", "r", stdin);
	freopen ("2538.out", "w", stdout);
#endif
}

const int N = 3e3 + 1e2, Mod = 998244353;

inline int fpm(int x, int power) {
	int res = 1;
	for (; power; power >>= 1, x = 1ll * x * x % Mod)
		if (power & 1) res = 1ll * res * x % Mod;
	return res;
}

int fac[N], ifac[N];

void Math_Init(int maxn) {
	fac[0] = ifac[0] = 1;
	For (i, 1, maxn) fac[i] = 1ll * fac[i - 1] * i % Mod;
	ifac[maxn] = fpm(fac[maxn], Mod - 2);
	Fordown (i, maxn - 1, 1) ifac[i] = ifac[i + 1] * (i + 1ll) % Mod;
}

inline int Comb(int n, int m) {
	if (n < 0 || m < 0 || n < m) return 0;
	return 1ll * fac[n] * ifac[m] % Mod * ifac[n - m] % Mod;
}

int n, m, k, a[N], f[N], g[N];

int main () {

	File();

	Math_Init(3000);

	for (int cases = read(); cases; -- cases) {

		n = read(); m = read(); k = read();
		For (i, 1, n) a[i] = read();

		For (i, 1, max(n, m)) f[i] = g[i] = 0;

		sort(a + 1, a + n + 1, greater<int>()); 

		f[0] = 1;
		For (i, 1, n) Fordown (j, min(i, m), 0)
			if (j <= k - 1) f[j] = (f[j] + 1ll * f[j - 1] * a[i]) % Mod;
			else f[j] = (f[j] + f[j - 1]) % Mod;

		For (i, 1, n) a[i] = read();
		sort(a + 1, a + n + 1);
		For (i, 1, n) Fordown (j, min(i, m), 0)
			if (j <= m - (k - 1))
				g[j] = (g[j] + 1ll * Comb(i - 1, j - 1) * a[i]) % Mod;
			else 
				g[j] = (g[j] + g[j - 1] + 1ll * Comb(i - 1, j - 1) * a[i]) % Mod;

		int ans = 0;
		For (i, 0, m)
			ans = (ans + 1ll * f[i] * g[m - i]) % Mod;
		printf ("%d\n", ans);

	}

    return 0;

}
posted @ 2018-05-25 21:24  zjp_shadow  阅读(838)  评论(0编辑  收藏  举报