题目
为了方便表示,令 “糖果” 为 \(a\),“药片” 为 \(b\)。
解法
首先 \(n\equiv k\pmod 2\),不然是无解的。不过好像没有这种数据(
由于没有重复的数字,我们可以确定 \(a>b\) 的组数为 \(s=\frac{n+k}{2}\)。
然后有一个比较神仙的 \(\mathtt{dp}\):将 \(a,b\) 分别进行排序,令 \(dp_{i,j}\) 为前 \(i\) 个 \(a\) 中 至少 有 \(j\) 组 \(a>b\) 的方案数。注意,我们 钦定 的 \(a<b\) 的匹配种类没有被计算。有:
\[dp_{i,j}=dp_{i-1,j}+dp_{i-1,j-1}\times \max\{0,lst_i-(j-1)\}
\]
其中 \(lst_i\) 为小于 \(a_i\) 的 \(b\) 的个数。后面乘上的那一坨就是使 \(a_i>b\) 的还未匹配的 \(b\) 的个数。
另外,\(\max\) 实际上可以去掉,因为当 \(lst_i-(j-1)<0\) 时,\(dp_{i-1,j-1}\) 必为 \(0\)。
令 \(f_i=dp_{n,i}\times (n-i)!\),这样 \(f_i\) 就计算上了 钦定 的 \(a<b\) 的匹配种类。
令 \(g_i\) 为所有 \(a\) 中 恰好 有 \(i\) 组 \(a>b\) 的方案数。发现对于 \(i\ge m\),\(f_m\) 中计算了 \(\text{C}(i,m)\) 次 \(g_i\)。考虑套用二项式反演,有:
\[g_s=\sum_{i=s}^n (-1)^{i-s}\cdot \text{C}(i,s)\cdot f_i
\]
代码
#include <cstdio>
#define print(x,y) write(x),putchar(y)
template <class T>
inline T read(const T sample) {
T x=0; char s; bool f=0;
while((s=getchar())>'9' or s<'0')
f|=(s=='-');
while(s>='0' and s<='9')
x=(x<<1)+(x<<3)+(s^48),
s=getchar();
return f?-x:x;
}
template <class T>
inline void write(const T x) {
if(x<0) {
putchar('-'),write(-x);
return;
}
if(x>9) write(x/10);
putchar(x%10^48);
}
#include <iostream>
#include <algorithm>
using namespace std;
const int mod=1e9+9,maxn=2005;
int n,k,a[maxn],b[maxn],lst[maxn];
int fac[maxn],ifac[maxn];
int dp[maxn][maxn];
int inv(int x,int y=mod-2) {
int r=1;
while(y) {
if(y&1) r=1ll*r*x%mod;
x=1ll*x*x%mod; y>>=1;
}
return r;
}
void init() {
fac[0]=1;
for(int i=1;i<=maxn-5;++i)
fac[i]=1ll*fac[i-1]*i%mod;
ifac[maxn-5]=inv(fac[maxn-5]);
for(int i=maxn-6;i>=0;--i)
ifac[i]=1ll*ifac[i+1]*(i+1)%mod;
}
int C(int n,int m) {
if(n<m) return 0;
return 1ll*fac[n]*ifac[m]%mod*ifac[n-m]%mod;
}
int main() {
n=read(9),k=read(9);
init();
for(int i=1;i<=n;++i)
a[i]=read(9);
for(int i=1;i<=n;++i)
b[i]=read(9);
sort(a+1,a+n+1);
sort(b+1,b+n+1);
int p=0;
for(int i=1;i<=n;++i) {
while(p+1<=n and b[p+1]<a[i])
++p;
lst[i]=p;
}
dp[0][0]=1;
for(int i=1;i<=n;++i) {
dp[i][0]=1;
for(int j=1;j<=n;++j)
dp[i][j]=(dp[i-1][j]+1ll*dp[i-1][j-1]*max(0,lst[i]-j+1)%mod)%mod;
}
int ans=0,tmp=1;
k=n+k>>1;
for(int i=n;i>=k;--i) {
if(i-k&1)
ans=(ans-1ll*C(i,k)*dp[n][i]%mod*tmp%mod+mod)%mod;
else
ans=(ans+1ll*C(i,k)*dp[n][i]%mod*tmp%mod)%mod;
tmp=1ll*tmp*(n-i+1)%mod;
}
print(ans,'\n');
return 0;
}