LOJ2538 PKUWC2018 Slay the Spire DP
不想放题面了,咕咕咕咕咕
这个期望明明是用来吓人的,其实要算的就是所有方案的最多伤害的和。
首先可以知道的是,能出强化牌就出强化牌(当然最后要留一张攻击牌出出去),且数字尽量大
所以说在强化牌数量$< K$时会打出所有强化牌和剩下的最大的攻击牌,而强化牌数量$\geq K$的时候则会打出$K-1$张强化牌和$1$张攻击牌,且它们的数字都是最大的
我们不妨计算每一种最优打出的方案存在在多少种抽取方案中。
设$f_{i,j}$表示使用$i$张强化牌,其中数值最小的牌是第$j$张时的方案的强化数值之和,$g_{i,j}$表示使用$i$张攻击牌,其中数值最小的牌是第$j$张时的方案的攻击数值之和,简单的前缀和优化DP就可以完成。处理完之后,所有抽取了$x$张强化牌和$y$张攻击牌,选择$i$张强化牌和$j$张攻击牌的方案的总伤害和就是
$$\sum\limits_k \sum\limits_l f_{i,k} \times g_{j,l} \times C_{k-1}^{x-i} \times C_{l-1}^{y-j}$$
两个理解:
①$f$和$g$中间用乘号是因为$f$中间的每一个方案和$g$中的每一个方案都可以对应产生一种抽取方式
②后面的两个组合数的意思是:对于强化牌,已经使用了$i$张,最小的是$k$,那么剩下的$x-i$张需要在剩余的$k-1$中抽取,方案数就是组合数,后面同理
那么我们枚举强化牌抽了多少张,就可以直接计算答案。
1 #include<bits/stdc++.h> 2 //This code is written by Itst 3 using namespace std; 4 5 inline int read(){ 6 int a = 0; 7 bool f = 0; 8 char c = getchar(); 9 while(c != EOF && !isdigit(c)){ 10 if(c == '-') 11 f = 1; 12 c = getchar(); 13 } 14 while(c != EOF && isdigit(c)){ 15 a = (a << 3) + (a << 1) + (c ^ '0'); 16 c = getchar(); 17 } 18 return f ? -a : a; 19 } 20 21 const int MOD = 998244353 , MAXN = 3000; 22 long long dp[MAXN + 10][MAXN + 10] , f[MAXN + 10][MAXN + 10] , g[MAXN + 10][MAXN + 10] , num[2][MAXN + 10] , jc[MAXN + 10] , ny[MAXN + 10] , N , M , K , ans; 23 24 bool cmp(int a , int b){ 25 return a > b; 26 } 27 28 inline int poww(long long a , int b){ 29 int times = 1; 30 while(b){ 31 if(b & 1) 32 times = times * a % MOD; 33 a = a * a % MOD; 34 b >>= 1; 35 } 36 return times; 37 } 38 39 signed main(){ 40 #ifndef ONLINE_JUDGE 41 freopen("2538.in" , "r" , stdin); 42 //freopen("2538.out" , "w" , stdout); 43 #endif 44 jc[0] = 1; 45 for(long long i = 1 ; i <= MAXN ; ++i) 46 jc[i] = jc[i - 1] * i % MOD; 47 ny[MAXN] = poww(jc[MAXN] , MOD - 2); 48 for(long long i = MAXN - 1 ; i >= 0 ; --i) 49 ny[i] = ny[i + 1] * (i + 1) % MOD; 50 for(int i = 0 ; i <= MAXN ; ++i) 51 dp[0][i] = 1; 52 for(int i = 1 ; i <= MAXN ; ++i) 53 for(int j = 1 ; j <= MAXN ; ++j) 54 dp[i][j] = (dp[i - 1][j - 1] + dp[i][j - 1]) % MOD; 55 for(int i = 0 ; i <= MAXN ; ++i) 56 f[0][i] = 1; 57 58 for(int T = read() ; T ; --T){ 59 N = read(); 60 M = read(); 61 K = read(); 62 for(int i = 1 ; i <= N ; ++i) 63 num[0][i] = read(); 64 sort(num[0] + 1 , num[0] + N + 1 , cmp); 65 for(int i = 1 ; i <= N ; ++i) 66 num[1][i] = read(); 67 sort(num[1] + 1 , num[1] + N + 1 , cmp); 68 ans = 0; 69 70 for(int i = 1 ; i <= N ; ++i) 71 for(int j = 1 ; j <= N ; ++j){ 72 f[i][j] = (1ll * f[i - 1][j - 1] * num[0][j] + f[i][j - 1]) % MOD; 73 g[i][j] = (g[i - 1][j - 1] + 1ll * (dp[i][j] - dp[i][j - 1] + MOD) * num[1][j] + g[i][j - 1]) % MOD; 74 } 75 76 for(int i = 0 ; i < K ; ++i){ 77 int sum1 = f[i][N] , sum2 = 0; 78 for(int j = K - i ; N - j >= M - K ; ++j) 79 sum2 = (sum2 + 1ll * (g[K - i][j] - g[K - i][j - 1] + MOD) * jc[N - j] % MOD * ny[M - K] % MOD * ny[N - j - (M - K)]) % MOD; 80 ans = (ans + 1ll * sum1 * sum2) % MOD; 81 } 82 83 for(int i = K ; i < M ; ++i){ 84 int sum1 = 0 , sum2 = 0; 85 for(int j = K - 1 ; N - j >= i - (K - 1) ; ++j) 86 sum1 = (sum1 + 1ll * (f[K - 1][j] - f[K - 1][j - 1] + MOD) * jc[N - j] % MOD * ny[i - (K - 1)] % MOD * ny[N - j - (i - (K - 1))]) % MOD; 87 for(int j = 1 ; N - j >= M - i - 1 ; ++j) 88 sum2 = (sum2 + 1ll * (g[1][j] - g[1][j - 1] + MOD) * jc[N - j] % MOD * ny[M - i - 1] % MOD * ny[N - j - (M - i - 1)]) % MOD; 89 ans = (ans + 1ll * sum1 * sum2) % MOD; 90 } 91 cout << ans << endl; 92 } 93 return 0; 94 }