[PKUWC2018]Slay the Spire
Solution
实际上是求最大伤害总和。
有一个只要有眼睛就能看出来的结论:能出强化牌就出强化牌,最后剩一张出攻击牌,当然如果强化牌不满\(k\)个就把强化牌出完剩下出攻击牌。因为强化牌都是大于等于1的正整数,所以带来的效果是至少让伤害翻一倍那么显然尽量出强化牌。(然鹅我可能真的没眼睛,看了20分钟才看到\(w_i\)都是正整数)
由乘法分配律,强化牌和攻击牌的贡献可以分开来计算。设\(F_{i,j}\)表示所有的 \(i\)张强化牌里选\(j\)张打出的情况的贡献 之和,\(G_{i,j}\)表示所有 \(i\)张攻击牌选\(j\)张打出的情况的贡献 之和。
那么有:
\[ans=sum\begin{cases}
F_{i,i}\times G_{m-i,k-i},i<k\\
F_{i,k-1}\times G_{m-i,1},i\geq k
\end{cases}
\]
但是F和G不太好算,考虑设\(dp\)数组辅助。
首先可以把牌从大到小排序,因为我肯定是在能选的牌堆中选最大的那些出掉。
设\(f_{i,j}\)表示选i张强化牌打出,其中最小的那张是第j张的贡献。\(g_{i,j}\)表示选i张攻击牌打出,其中最小的是第j张的贡献。简单dp一下,利用前缀和优化就是\(\text O(n^2)\)的:
\[\begin{align}
&f_{i,j}=a_j\times \sum\limits_{i-1\leq k\leq j-1}f_{i-1,k}\\
&g_{i,j}=b_j\times \binom{j-1}{i-1} + \sum\limits_{i-1\leq k\leq j-1}g_{i-1,k}
\end{align}
\]
求出这个之后F,G就是f,g乘上一个组合数再算算的事情:
\[F_{x,y}=\sum\limits_{i\geq y}f_{y,i}\times \binom{n-i}{x-y}\\
G_{x,y}=\sum\limits_{i\geq y}g_{y,i}\times \binom{n-i}{x-y}
\]
F,G对总共\(O(n)\)个,所以计算答案也是\(O(n^2)\)级别的。
像这种题不太会做的原因是想不到可以给状态加限制,例如此题加一个强制第j个是打出的牌中最小的那个就很好转移了。
另外就是函数C(n,m),就是组合数的函数,一定要记得特判 \(n<m\) 的情况!!!
Code
#include<bits/stdc++.h>
using namespace std;
typedef long long ll;
inline int read(){//be careful for long long!
register int x=0,f=1;register char ch=getchar();
while(!isdigit(ch)){if(ch=='-')f=0;ch=getchar();}
while(isdigit(ch)){x=x*10+(ch^'0');ch=getchar();}
return f?x:-x;
}
const int N=3e3+10,mod=998244353;
int n,m,k,a[N],b[N],tmp[N],fac[N],ifc[N];
int f[N][N],g[N][N];
inline int power(int base,int n){int ans=1;for(;n;n>>=1,base=1ll*base*base%mod)if(n&1)ans=1ll*ans*base%mod;return ans;}
inline int C(int n,int m){if(n<m)return 0;return 1ll*fac[n]*ifc[m%mod]%mod*ifc[n-m]%mod;}
inline bool cmp(const int &x,const int &y){return x>y;}
inline int F(int x,int y){
int ans=0;
for(int i=y;i<=n;++i)ans=(ans+1ll*f[y][i]*C(n-i,x-y)%mod)%mod;
return ans;
}
inline int G(int x,int y){
int ans=0;
for(int i=y;i<=n;++i)ans=(ans+1ll*g[y][i]*C(n-i,x-y)%mod)%mod;
return ans;
}
int main(){
for(int i=fac[0]=1;i<N;++i)fac[i]=1ll*fac[i-1]*i%mod;
ifc[N-1]=power(fac[N-1],mod-2);
for(int i=N-2;~i;--i)ifc[i]=1ll*ifc[i+1]*(i+1)%mod;
int T=read();
while(T--){
n=read(),m=read(),k=read();
for(int i=1;i<=n;++i)a[i]=read();
for(int i=1;i<=n;++i)b[i]=read();
sort(a+1,a+n+1,cmp),sort(b+1,b+n+1,cmp);
f[0][0]=1;
for(int i=1;i<=n;++i)f[1][i]=a[i],g[1][i]=b[i];
for(int i=2;i<=n;++i){
int sf=f[i-1][i-1],sg=g[i-1][i-1];
for(int j=i;j<=n;++j){
f[i][j]=1ll*a[j]*sf%mod;
g[i][j]=(1ll*b[j]*C(j-1,i-1)%mod+sg)%mod;
sf=(sf+f[i-1][j])%mod;sg=(sg+g[i-1][j])%mod;
}
}
int ans=0;
for(int i=max(m-n,0);i<=min(n,m-1);++i){
if(i<k)ans=(ans+1ll*F(i,i)*G(m-i,k-i)%mod)%mod;
else ans=(ans+1ll*F(i,k-1)*G(m-i,1)%mod)%mod;
}
printf("%d\n",ans);
}
return 0;
}