[PKUWC2018]Slay the Spire
我竟然能做出九老师的组合计数题,尽管这题很呆
我们先考虑一个简单的问题,如果给定你要选出来的\(m\)张卡牌,如何做到攻击伤害最高
非常显然,因为保证了强化牌上的数字大于\(1\),所以我们优先选择那些强化牌,毕竟最小的一张只能翻两倍的强化牌都要比选择攻击牌好;选完强化牌之后剩下的攻击牌自然是越大越好
当然了,我们也不可能选出\(k\)张强化牌来,这样什么伤害都造不成,所以如果强化牌数量大于\(k\),我们就选择前\(k-1\)大的强化牌,配上最大的攻击牌
我们发现给定的强化牌和攻击牌都是无序的,我们先排序再说
之后我们就可以大力\(dp\)了
设\(dp_{i,j}\)表示前\(i\)张强化牌里选出\(j\)张所有强化牌乘积的和,\(f_{i,j}\)表示前\(i\)张强化牌里选择了\(j\)张且第\(i\)张一定被选择的乘积和
显然有转移
\[dp_{i,j}=dp_{i-1,j}+dp_{i-1,j-1}\times a_i
\]
\[f_{i,j}=dp_{i-1,j-1}\times a_i
\]
攻击牌这边我们设\(h_{i,j}\)表示前\(i\)张攻击牌里选择了\(j\)张的所有方案的和,\(g_{i,j}\)表示强迫选择第\(i\)张
也有转移
\[h_{i,j}=h_{i-1,j}+h_{i-1,j-1}+C_{i-1}^{j-1}\times b_i
\]
\[g_{i,j}=h_{i-1,j-1}+C_{i-1}^{j-1}\times b_i
\]
现在我们考虑枚举选择了多少张强化牌
设选择了\(i\)张强化牌,自然也就需要选择\(m-i\)张攻击牌
如果\(i<k\),那么这些强化牌都是要直接用的,选择的\(m-i\)张攻击牌里能被用于打出的\(k\)张的也就只有\(k-i\)张,而剩下的\(m-i-(k-i)=m-k\)张牌我们在后面随便选一下
于是答案就是
\[\sum_{j=0}^nf_{j,i}\times \sum_{j=0}^ng_{j,k-i}C_{n-j}^{m-k}
\]
如果\(i>=k\),那么我们也只能选择\(k-1\)张强化牌,剩下的多选出来的\(i-k+1\)张我们还是从后面的位置里选出来,实际上需要的攻击牌也就只有\(1\)张,需要额外多选的攻击牌也就只有\(m-i-1\)张
答案就是
\[\sum_{j=0}^nf_{j,k-1} C_{n-j}^{i-k+1}\times \sum_{j=0}^ng_{j,1}C_{n-j}^{m-i-1}
\]
代码
#include<cstdio>
#include<cstring>
#include<iostream>
#include<algorithm>
#define re register
#define LL long long
#define max(a,b) ((a)>(b)?(a):(b))
#define min(a,b) ((a)<(b)?(a):(b))
inline int read() {
char c=getchar();int x=0;while(c<'0'||x>'9') c=getchar();
while(c>='0'&&c<='9') x=(x<<3)+(x<<1)+c-48,c=getchar();return x;
}
const int mod=998244353;
const int maxn=3005;
int c[maxn][maxn];
int dp[maxn][maxn],f[maxn][maxn];
int g[maxn][maxn],h[maxn][maxn];
int T,n,m,k;
int a[maxn],b[maxn];
inline int C(int n,int m) {if(m>n) return 0;return c[n][m];}
inline int cmp(int A,int B) {return A>B;}
int main() {
T=read();int M=maxn-5;
for(re int i=0;i<=M;i++) c[i][0]=c[i][i]=1;
for(re int i=1;i<=M;i++)
for(re int j=1;j<i;j++)
c[i][j]=(c[i-1][j-1]+c[i-1][j])%mod;
while(T--) {
n=read(),m=read(),k=read();
for(re int i=1;i<=n;i++) a[i]=read();
for(re int i=1;i<=n;i++) b[i]=read();
std::sort(a+1,a+n+1,cmp);std::sort(b+1,b+n+1,cmp);
dp[0][0]=1;
for(re int i=1;i<=n;i++)
for(re int j=0;j<=i;j++) {
dp[i][j]=dp[i-1][j];
if(j) dp[i][j]=(dp[i][j]+1ll*dp[i-1][j-1]*a[i]%mod)%mod;
}
f[0][0]=1;
for(re int i=1;i<=n;i++)
for(re int j=1;j<=i;j++)
f[i][j]=1ll*dp[i-1][j-1]*a[i]%mod;
h[0][0]=0;
for(re int i=1;i<=n;i++)
for(re int j=0;j<=i;j++) {
h[i][j]=h[i-1][j];
if(j) h[i][j]=(h[i][j]+1ll*C(i-1,j-1)*b[i]%mod+h[i-1][j-1])%mod;
}
for(re int i=1;i<=n;i++)
for(re int j=1;j<=i;j++)
g[i][j]=(1ll*C(i-1,j-1)*b[i]%mod+h[i-1][j-1])%mod;
int ans=0;
for(re int i=0;i<m;i++) {
int now=0;
for(re int j=0;j<=n;j++) {
if(i<=k-1) now=(now+f[j][i])%mod;
else now=(now+1ll*f[j][k-1]*C(n-j,i-k+1)%mod)%mod;
}
int tot=0,res=0;
if(i<=k-1) res=k-i;else res=1;
for(re int j=0;j<=n;j++)
tot=(tot+1ll*g[j][res]*C(n-j,m-i-res)%mod)%mod;
ans=(ans+1ll*now*tot%mod)%mod;
}
printf("%d\n",ans);
}
return 0;
}