【bzoj 3622】已经没有什么好害怕的了
看到这个数据范围就发现我们需要一个\(O(n^2)\)的做法了,那大概率是\(dp\)了
看到恰好\(k\)个我们就知道这基本是个容斥了
首先解方程发现我们需要使得\(a>b\)的恰好有\(\frac{n+k}{2}\)组
如果不整除我们直接输出\(0\)就好了
之后开始使用套路,直接广义容斥
\[ans=\sum_{i=k}^n(-1)^{i-k}\binom{i}{k}g_i
\]
\(g_i\)表配出至少\(i\)对\(a>b\)的情况
显然我们现在需要一个\(dp\)来算一下\(g\)
首先发现两个数组是没有顺序的,所以先习惯性排个序
设\(dp_{i,j}\)表示从\(a\)数组的前\(i\)个数中,已经配出\(j\)对\(a>b\)的方案数
边界\(dp_{0,0}=1\)
我们排序的作用这个时候就体现出来了,我们设\(d_i\)表示满足\(b_j<a_i\)的最大的\(j\)
由于\(a,b\)两个数组都是有序的,我们知道\(d_i\)肯定是单调不降的
于是有这样的方程
\[dp_{i,j}=dp_{i-1,j}+max(d_i-(j-1),0)dp_{i-1,j-1}
\]
就是考虑对于第\(i\)个数满足\(a>b\)的只有\(d_i\)个,减去和前\(i-1\)个匹配的\(j-1\)个,剩下的我们随便找出一个来匹配就好了
之后\(g_i=dp_{n,i}(n-i)!\),就是让没有满足\(a<b\)的那些随便匹配一下就好
代码
#include<algorithm>
#include<iostream>
#include<cstring>
#include<cstdio>
#define re register
#define LL long long
#define max(a,b) ((a)>(b)?(a):(b))
#define min(a,b) ((a)<(b)?(a):(b))
#define pt putchar(1),puts("")
const int maxn=2e3+5;
const int mod=1e9+9;
inline int read() {
char c=getchar();int x=0;while(c<'0'||c>'9') c=getchar();
while(c>='0'&&c<='9') x=(x<<3)+(x<<1)+c-48,c=getchar();return x;
}
int n,k;
int dp[maxn][maxn];
int a[maxn],b[maxn];
int fac[maxn],inv[maxn];
inline LL ksm(LL a,int b) {
LL S=1;
while(b) {if(b&1) S=S*a%mod;b>>=1;a=a*a%mod;}
return S;
}
inline int C(int n,int m) {
if(m>n) return 0;
return 1ll*fac[n]*inv[n-m]%mod*inv[m]%mod;
}
int main() {
n=read();k=read();
for(re int i=1;i<=n;i++) a[i]=read();
for(re int i=1;i<=n;i++) b[i]=read();
if((n+k)&1) {puts("0");return 0;}
std::sort(a+1,a+n+1),std::sort(b+1,b+n+1);
fac[0]=1;
for(re int i=1;i<=n;i++) fac[i]=(1ll*i*fac[i-1])%mod;
inv[n]=ksm(fac[n],mod-2);
for(re int i=n-1;i>=0;--i) inv[i]=(1ll*(i+1)*inv[i+1])%mod;
dp[0][0]=1;
for(re int i=1;i<=n;i++) {
int cnt=0;
for(re int j=1;j<=n;j++)
cnt+=(a[i]>b[j]);
for(re int j=0;j<=i;j++)
dp[i][j]=dp[i-1][j];
for(re int j=1;j<=i;j++)
dp[i][j]=(dp[i][j]+1ll*dp[i-1][j-1]*max(cnt-j+1,0)%mod)%mod;
}
k=(n+k)/2;
LL ans=0;
for(re int i=k;i<=n;i++)
if((i-k)&1) ans=(ans-1ll*C(i,k)*dp[n][i]%mod*fac[n-i]%mod+mod)%mod;
else ans=(ans+1ll*C(i,k)*dp[n][i]%mod*fac[n-i]%mod)%mod;
printf("%d\n",(int)ans);
return 0;
}