题解 洛谷 P5299 【[PKUWC2018]Slay the Spire】

先考虑有 \(m\) 张牌,打 \(k\) 张的最优策略。发现强化牌的效果至少是翻倍,所以最优策略一定是在至少打出一张攻击牌的前提下,尽可能的多打强化牌,强化牌数量不够时,再由大到小打攻击牌。

\(F_{i,j}\) 为选出 \(i\) 张强化牌,打出 \(j\) 张的效果, \(G_{i,j}\) 为选出 \(i\) 张攻击牌,打出 \(j\) 张的效果,根据最优策略,枚举选出强化牌的个数,得答案为:

\[\large\begin{aligned} ans&=\sum_{i=\max(m-n,0)}^{\min(n,m)} val_i \\ \\ val_i &= \begin{cases} F_{i,i}G_{m-i,k-i} &i < k \\ \\ F_{i,k-1}G_{m-i,1} & i \geqslant k \\ \end{cases} \end{aligned} \]

发现直接求 \(F,G\) 不好求,因为最优策略一定是先打权值大的牌,所以先将两种牌从大到小排序,设 \(f_{i,j}\) 为考虑了排序后的前 \(i\) 张强化牌且第 \(i\) 张必选,选了 \(j\) 张牌的效果,\(g_{i,j}\) 为考虑了排序后的前 \(i\) 张攻击牌且第 \(i\) 张必选,选了 \(j\) 张牌的效果,得:

\[\large\begin{aligned} f_{i,j}&=v_i\sum_{k=j-1}^{i-1} f_{k,j-1} \\ g_{i,j}&=\binom{i-1}{j-1}v_i + \sum_{k=j-1}^{i-1} g_{k,j-1} \end{aligned} \]

乘上组合数的原因是考虑转移过来的方案数,这里可以前缀和优化为 \(O(n^2)\)。然后考虑 \(f,g\)\(F,G\) 的贡献,得:

\[\large\begin{aligned} F_{i,j} = \sum_{k=j}^n \binom{n-k}{i-j} f_{k,j} \\ G_{i,j} = \sum_{k=j}^n \binom{n-k}{i-j} g_{k,j} \\ \end{aligned} \]

直接计算 \(F,G\)\(O(n^3)\) 的,但通过 \(f,g\) 计算,需要哪个算哪个,复杂度就是 \(O(n^2)\) 了。

#include<bits/stdc++.h>
#define maxn 3010
#define all 3000
#define p 998244353
using namespace std;
typedef long long ll;
template<typename T> inline void read(T &x)
{
    x=0;char c=getchar();bool flag=false;
    while(!isdigit(c)){if(c=='-')flag=true;c=getchar();}
    while(isdigit(c)){x=(x<<1)+(x<<3)+(c^48);c=getchar();}
    if(flag)x=-x;
}
int T,n,m,k;
ll v1[maxn],v2[maxn],s1[maxn],s2[maxn],f[maxn][maxn],g[maxn][maxn],fac[maxn],ifac[maxn];
bool cmp(const ll &a,const ll &b)
{
    return a>b;
}
ll inv(ll x)
{
    ll y=p-2,v=1;
    while(y)
    {
        if(y&1) v=v*x%p;
        x=x*x%p,y>>=1;
    }
    return v;
}
void init()
{
    fac[0]=ifac[0]=1;
    for(int i=1;i<=all;++i) fac[i]=fac[i-1]*i%p;
    ifac[all]=inv(fac[all]);
    for(int i=all-1;i;--i) ifac[i]=ifac[i+1]*(i+1)%p;
}
ll C(int n,int m)
{
    if(n<m) return 0;
    return fac[n]*ifac[n-m]%p*ifac[m]%p;
}
ll F(int i,int j)
{
    ll v=0;
    for(int k=j;k<=n;++k)
        v=(v+C(n-k,i-j)*f[k][j]%p)%p;
    return v;
}
ll G(int i,int j)
{
    ll v=0;
    for(int k=j;k<=n;++k)
        v=(v+C(n-k,i-j)*g[k][j]%p)%p;
    return v;
}
ll solve()
{
    read(n),read(m),read(k);
    for(int i=1;i<=n;++i) read(v1[i]);
    for(int i=1;i<=n;++i) read(v2[i]);
    sort(v1+1,v1+n+1,cmp),sort(v2+1,v2+n+1,cmp);
    f[0][0]=1,memset(s1,0,sizeof(s1)),memset(s2,0,sizeof(s2));
    for(int i=1;i<=n;++i)
    {
        for(int j=1;j<=i;++j)
        {
            s1[j-1]=(s1[j-1]+f[i-1][j-1])%p,f[i][j]=v1[i]*s1[j-1]%p;
            s2[j-1]=(s2[j-1]+g[i-1][j-1])%p,g[i][j]=(C(i-1,j-1)*v2[i]%p+s2[j-1])%p;
        }
    }
    ll ans=0;
    for(int i=max(m-n,0);i<=min(n,m);++i)
    {
        if(i<k) ans=(ans+F(i,i)*G(m-i,k-i)%p)%p;
        else ans=(ans+F(i,k-1)*G(m-i,1)%p)%p;
    }
    return ans;
}
int main()
{
    init(),read(T);
    while(T--) printf("%lld\n",solve());
    return 0;
}
posted @ 2020-08-11 17:40  lhm_liu  阅读(153)  评论(0编辑  收藏  举报