LOJ#2538. 「PKUWC2018」Slay the Spire DP+组合

自己想+切掉的,开心.    

推一推就发现如果要权值最大,就尽量使用那个乘法.    

然后就分两种情况讨论一下:

1. 乘法用到 k-1 次.  

2. 乘法用不到 k-1 次.   

code: 

#include <bits/stdc++.h>  
#define ll long long   
#define N 3006  
#define mod 998244353 
#define setIO(s) freopen(s".in","r",stdin) 
using namespace std;   
int f[N][N],g[N][N],inv[N],fac[N]; 
int w1[N],w2[N],F[N],s1[N],sk[N];     
bool cmp(int a,int b) { return a>b; }   
inline int qpow(int x,int y) 
{
	int tmp=1; 
	for(;y;y>>=1,x=(ll)x*x%mod)   
		if(y&1) tmp=(ll)tmp*x%mod;  
	return tmp; 
}
inline int INV(int x) { return qpow(x,mod-2); }     
inline int C(int x,int y) { 
	if(x<0||y<0||x<y) return 0; 
	return (ll)fac[x]*inv[y]%mod*inv[x-y]%mod; 
}                  
void init()     
{ 
	fac[0]=inv[0]=1;   
	for(int i=1;i<N;++i) fac[i]=(ll)fac[i-1]*i%mod,inv[i]=INV(fac[i]);   
}       
void solve() 
{      
	int n,m,k;  
	scanf("%d%d%d",&n,&m,&k);           
	for(int i=1;i<=n;++i)  scanf("%d",&w2[i]);   
	for(int i=1;i<=n;++i)  scanf("%d",&w1[i]);   
	sort(w2+1,w2+1+n,cmp);   
	sort(w1+1,w1+1+n,cmp);        
	memset(s1,0,sizeof(s1)); 
	memset(sk,0,sizeof(sk));    
	memset(F,0,sizeof(F));  
	g[0][0]=1;       
	// 考虑对 f 动一动手        
	for(int i=1;i<=n;++i) 
	{
		g[i][0]=1;   
		for(int j=1;j<=i;++j) 
		{
			f[i][j]=(ll)(f[i-1][j]+f[i-1][j-1]+(ll)w1[i]*C(i-1,j-1)%mod)%mod;                         
			g[i][j]=(ll)(g[i-1][j]+(ll)g[i-1][j-1]*w2[i]%mod)%mod;                      
		}                
		for(int j=1;j<=m;++j) 
		{ 
			if(n-i>=m-k) (F[j]+=(ll)(f[i-1][j-1]+(ll)w1[i]*C(i-1,j-1)%mod)%mod*C(n-i,m-k)%mod)%=mod;         
			(s1[j]+=(ll)w1[i]*C(n-i,j-1)%mod)%=mod;       
			if(j>=k-1&&k>1)    
				(sk[j]+=(ll)g[i-1][k-2]*w2[i]%mod*C(n-i,j-k+1)%mod)%=mod;           
		}                
	}          
	int ans=0;     
	for(int i=1;i<=m;++i) 
	{
		int x=i,y=m-i;                
		if(k==1) { (ans+=(ll)s1[x]*C(n,y)%mod)%=mod; continue;  }
   		if(y>=k-1)           
			(ans+=(ll)s1[x]*sk[y]%mod)%=mod;   
		else 
			(ans+=(ll)g[n][y]*F[k-y]%mod)%=mod;    
	}    
	printf("%d\n",ans);     
}
int main()
{ 
	// setIO("input");                  
	int T; 
	scanf("%d",&T);  
	init();  
	while(T--) solve();  
	return 0; 
}

  

 

posted @ 2020-05-29 17:15  EM-LGH  阅读(155)  评论(0编辑  收藏  举报