题解-PKUWC2018 Slay the Spire
Problem
Solution
在考场上当然要学会写暴力,考虑如果手上已经有了\(a\)张攻击牌和\(b\)张强化牌:
- 首先强化牌会在攻击牌之前用(废话),其次要将两种牌分别从大往小打,即排个序先(也是废话)
- 要尽量打强化牌,最后再打一张攻击牌(由于每张强化牌至少乘二,所以打一张强化牌一定不比多打一张攻击牌差)
由于\(n\leq 3000\),预估复杂度为\(O(n^2)\),所以应该可以枚举两种牌的数量
设两个状态,\(F[i][j]\)表示选取\(i\)张强化牌,打出\(j\)张的强化效果,\(G[i][j]\)表示攻击牌,由于有\(k\)的限制,所以分类讨论一下:
- 若当前升级牌数量超过\(k\),则需要打出\(k-1\)张升级牌,再出一张最大值的攻击牌
- 若当前升级牌数量少于\(k\),则需要打出所有升级牌,再从攻击牌中找到最大的一些牌
然后答案就是
\[Ans=\sum_{i=0}^{k-1}F[i][i]\cdot G[m-i][k-i]+\sum_{i=k}^{\min(n,m)}F(i,k-1)\cdot G(m-i,1)
\]
接下来考虑如何求出\(F,G\)数组,由于需要在所有选牌情况里需要贪心,所以一种解决方案是构造辅助数组
考虑对牌从大到小排序,\(f[i][j],g[i][j]\)表示前\(i\)张牌中选取了\(j\)张牌,且必选自己的情况中的和
转移方程(可以利用前缀和做到\(O(n^2)\)转移):
\(f[i][j]=w_i\cdot \sum_{l=j-1}^{i-1}f[l][j-1] \\ g[i][j]=\binom {i-1}{j-1}w_i+\sum_{l=j-1}^{i-1}g[l][j-1]\)
然后就可以枚举第\(j\)张牌是第几张来将\(f\rightarrow F,g\rightarrow G\):
\(F[i][j]=\sum_{l=j}^n\binom {n-l}{i-j}f[l][j] \\ G[i][j]=\sum_{l=j}^n\binom {n-l}{i-j}g[l][j]\)
由于求一个\(F,G\)是\(O(n)\)的,所以求出所有的\(F,G\)是\(O(n^3)\)的,但由于我们计算答案的时候只需要用到\(O(n)\)个\(F,G\),所以每次暴力从\(f,g\)统计即可
Code
#include <bits/stdc++.h>
using namespace std;
inline void read(int&x){
char c11=getchar();x=0;while(!isdigit(c11))c11=getchar();
while(isdigit(c11))x=x*10+c11-'0',c11=getchar();
}
const int N=3010,p = 998244353;
int fac[N],inv[N],f[N][N],g[N][N];
int w1[N],w2[N],t1[N],t2[N];
int n,m,k;
inline int qpow(int A,int B){
int res(1);while(B){
if(B&1)res=1ll*res*A%p;
A=1ll*A*A%p;B>>=1;
}return res;
}
inline int c(int nn,int mm){return 1ll*fac[nn]*inv[mm]%p*inv[nn-mm]%p;}
inline int cmp(const int&A,const int&B) {return A>B;}
template <typename _Tp> inline int qm(_Tp x){return x<p?x:x-p;}
void prework(){
fac[0]=inv[0]=1;
for(int i=1;i<N;++i)fac[i]=1ll*fac[i-1]*i%p;
inv[N-1]=qpow(fac[N-1],p-2);
for(int i=N-2;i;--i)inv[i]=1ll*inv[i+1]*(i+1)%p;
}
inline int F(int i,int j){
static int res;res=0;
for(int l=n-max(i-j,0);l>=j;--l)
res=qm(res+1ll*c(n-l,i-j)*f[l][j]%p);
return res;
}
inline int G(int i,int j){
static int res;res=0;
for(int l=n-max(i-j,0);l>=j;--l)
res=qm(res+1ll*c(n-l,i-j)*g[l][j]%p);
return res;
}
void init();void work();
int main(){
prework();int T;read(T);
while(T--)init(),work();
return 0;
}
void work(){
int ans=0;
for(int i=min(n,m);~i&&i>=m-n;--i)
if(i<k)ans=qm(ans+1ll*F(i,i)*G(m-i,k-i)%p);
else ans=qm(ans+1ll*F(i,k-1)*G(m-i,1)%p);
printf("%d\n",ans);
}
void init(){
read(n),read(m),read(k);f[0][0]=1;
memset(t1,0,sizeof(t1));memset(t2,0,sizeof(t2));
for(int i=1;i<=n;++i)read(w1[i]);sort(w1+1,w1+n+1,cmp);
for(int i=1;i<=n;++i)read(w2[i]);sort(w2+1,w2+n+1,cmp);
for(int i=1;i<=n;++i)f[i][1]=w1[i],g[i][1]=w2[i];
for(int i=1;i<=n;++i)
for(int j=2;j<=i;++j){
t1[j-1]=qm(t1[j-1]+f[i-1][j-1]);
f[i][j]=1ll*w1[i]*t1[j-1]%p;
t2[j-1]=qm(t2[j-1]+g[i-1][j-1]);
g[i][j]=qm(1ll*c(i-1,j-1)*w2[i]%p+t2[j-1]);
}
}