题解 洛谷 P5299 【[PKUWC2018]Slay the Spire】
先考虑有 \(m\) 张牌,打 \(k\) 张的最优策略。发现强化牌的效果至少是翻倍,所以最优策略一定是在至少打出一张攻击牌的前提下,尽可能的多打强化牌,强化牌数量不够时,再由大到小打攻击牌。
设 \(F_{i,j}\) 为选出 \(i\) 张强化牌,打出 \(j\) 张的效果, \(G_{i,j}\) 为选出 \(i\) 张攻击牌,打出 \(j\) 张的效果,根据最优策略,枚举选出强化牌的个数,得答案为:
\[\large\begin{aligned}
ans&=\sum_{i=\max(m-n,0)}^{\min(n,m)} val_i \\
\\
val_i &=
\begin{cases}
F_{i,i}G_{m-i,k-i} &i < k \\
\\
F_{i,k-1}G_{m-i,1} & i \geqslant k \\
\end{cases}
\end{aligned}
\]
发现直接求 \(F,G\) 不好求,因为最优策略一定是先打权值大的牌,所以先将两种牌从大到小排序,设 \(f_{i,j}\) 为考虑了排序后的前 \(i\) 张强化牌且第 \(i\) 张必选,选了 \(j\) 张牌的效果,\(g_{i,j}\) 为考虑了排序后的前 \(i\) 张攻击牌且第 \(i\) 张必选,选了 \(j\) 张牌的效果,得:
\[\large\begin{aligned}
f_{i,j}&=v_i\sum_{k=j-1}^{i-1} f_{k,j-1} \\
g_{i,j}&=\binom{i-1}{j-1}v_i + \sum_{k=j-1}^{i-1} g_{k,j-1}
\end{aligned}
\]
乘上组合数的原因是考虑转移过来的方案数,这里可以前缀和优化为 \(O(n^2)\)。然后考虑 \(f,g\) 对 \(F,G\) 的贡献,得:
\[\large\begin{aligned}
F_{i,j} = \sum_{k=j}^n \binom{n-k}{i-j} f_{k,j} \\
G_{i,j} = \sum_{k=j}^n \binom{n-k}{i-j} g_{k,j} \\
\end{aligned}
\]
直接计算 \(F,G\) 是 \(O(n^3)\) 的,但通过 \(f,g\) 计算,需要哪个算哪个,复杂度就是 \(O(n^2)\) 了。
#include<bits/stdc++.h>
#define maxn 3010
#define all 3000
#define p 998244353
using namespace std;
typedef long long ll;
template<typename T> inline void read(T &x)
{
x=0;char c=getchar();bool flag=false;
while(!isdigit(c)){if(c=='-')flag=true;c=getchar();}
while(isdigit(c)){x=(x<<1)+(x<<3)+(c^48);c=getchar();}
if(flag)x=-x;
}
int T,n,m,k;
ll v1[maxn],v2[maxn],s1[maxn],s2[maxn],f[maxn][maxn],g[maxn][maxn],fac[maxn],ifac[maxn];
bool cmp(const ll &a,const ll &b)
{
return a>b;
}
ll inv(ll x)
{
ll y=p-2,v=1;
while(y)
{
if(y&1) v=v*x%p;
x=x*x%p,y>>=1;
}
return v;
}
void init()
{
fac[0]=ifac[0]=1;
for(int i=1;i<=all;++i) fac[i]=fac[i-1]*i%p;
ifac[all]=inv(fac[all]);
for(int i=all-1;i;--i) ifac[i]=ifac[i+1]*(i+1)%p;
}
ll C(int n,int m)
{
if(n<m) return 0;
return fac[n]*ifac[n-m]%p*ifac[m]%p;
}
ll F(int i,int j)
{
ll v=0;
for(int k=j;k<=n;++k)
v=(v+C(n-k,i-j)*f[k][j]%p)%p;
return v;
}
ll G(int i,int j)
{
ll v=0;
for(int k=j;k<=n;++k)
v=(v+C(n-k,i-j)*g[k][j]%p)%p;
return v;
}
ll solve()
{
read(n),read(m),read(k);
for(int i=1;i<=n;++i) read(v1[i]);
for(int i=1;i<=n;++i) read(v2[i]);
sort(v1+1,v1+n+1,cmp),sort(v2+1,v2+n+1,cmp);
f[0][0]=1,memset(s1,0,sizeof(s1)),memset(s2,0,sizeof(s2));
for(int i=1;i<=n;++i)
{
for(int j=1;j<=i;++j)
{
s1[j-1]=(s1[j-1]+f[i-1][j-1])%p,f[i][j]=v1[i]*s1[j-1]%p;
s2[j-1]=(s2[j-1]+g[i-1][j-1])%p,g[i][j]=(C(i-1,j-1)*v2[i]%p+s2[j-1])%p;
}
}
ll ans=0;
for(int i=max(m-n,0);i<=min(n,m);++i)
{
if(i<k) ans=(ans+F(i,i)*G(m-i,k-i)%p)%p;
else ans=(ans+F(i,k-1)*G(m-i,1)%p)%p;
}
return ans;
}
int main()
{
init(),read(T);
while(T--) printf("%lld\n",solve());
return 0;
}