题解-洛谷P4859 已经没有什么好害怕的了
给定 \(n\) 和 \(k\),\(n\) 个糖果能量 \(a_i\) 和 \(n\) 个药片能量 \(b_i\),每个 \(a_i\) 和 \(b_i\) 互不相等。将糖果和药片一一对应,求 糖果能量大于药片 比 药片能量大于糖果 多 \(k\) 组的方案数。
数据范围:\(1\le n\le 2000\),\(0\le k\le n\)。
萌新初学二项式反演,这是第一道完全自己做出来的题,所以写篇题解庆祝并提升理解。
有 \(\frac{n+k}{2}\) 组糖果能量大于药片,\(\frac{n-k}{2}\) 组药片能量大于糖果。
如果 \(n+k\) 是奇数,直接答案为 \(0\) 特判掉。
\(f(i)\) 表示 \(i\) 组糖果能量大于药片,\(n-i\) 组药片能量大于糖果的方案数。
\(g(i)\) 表示 \(i\) 组糖果能量大于药片,\(n-i\) 组随意的方案数。
二项式反演必然有 \(f(i)\) 和 \(g(i)\),往往前者表示 \(i\) 个符合条件 \(a\) 剩下符合另条件 \(b\),后者表示 \(i\) 个符合条件 \(a\) 剩下随意。
先考虑 \(g(i)\) 怎么独立地求,蒟蒻想到了 \(\tt dp\)。
将 \(a_i\) 和 \(b_i\) 排序,现在 \(a_i<a_{i+1}\),\(b_i<b_{i+1}\)。
比如 \(b_i<a_1<b_{i+1}\),\(b_j<a_2<b_{j+1}(i<j)\)。
所以 \(a_1\) 可以对应 \(b_1\sim b_i\),\(a_2\) 可以对应 \(b_1\sim b_j\)。
因为 \(a_1\) 对于的 \(b_x\) 满足 \(x<i<j\),所以必然占了一个 \(a_2\) 可以对应的位。
所以有 \(i(j-1)\) 种对应法。
设 \(F_{i,j}\) 表示看了 \(a_1\sim a_i\),对应了 \(j\) 组的方案数。
令 \(p(i)\) 表示 \(b_{p(i)}<a_i<b_{p(i)+1}\)。
同理,所以 \(F(0,0)=1\)。
二项式反演来了:
答案是 \(f(\frac{n+k}{2})\),带进去算就好了。
时间复杂度 \(\Theta(n^2)\),空间复杂度 \(\Theta(n^2)\)。
- 代码
下标从 \(0\) 开始的...巨佬们琢磨琢磨吧。\(\tt /kel\)。
#include <bits/stdc++.h>
using namespace std;
//Start
typedef long long ll;
typedef double db;
#define mp(a,b) make_pair(a,b)
#define x first
#define y second
#define be(a) a.begin()
#define en(a) a.end()
#define sz(a) int((a).size())
#define pb(a) push_back(a)
const int inf=0x3f3f3f3f;
const ll INF=0x3f3f3f3f3f3f3f3f;
//Data
const int mod=1e9+9;
//Main
int main(){
cin.tie(0);
int n,k; cin>>n>>k;
vector<int> a(n),b(n);
for(int&ai:a) cin>>ai;
for(int&bi:b) cin>>bi;
if((n-k)&1) return cout<<0<<'\n',0;
sort(be(a),en(a));
sort(be(b),en(b));
vector<vector<int>> f(n+1,vector<int>(n+1,0));
f[0][0]=1;
for(int i=0,p=-1;i<n;i++){
while(p+1<n&&b[p+1]<a[i]) p++;
for(int j=0;j<n+1;j++) f[i+1][j]=f[i][j];
for(int j=0;j<n;j++) (f[i+1][j+1]+=(ll)f[i][j]*(p-j+1)%mod)%=mod;
}
for(int j=n,s=1;j>=0;j--) f[n][j]=(ll)f[n][j]*s%mod,s=(ll)s*(n-j+1)%mod;
vector<vector<int>> c(n+1,vector<int>(n+1,0));
for(int i=0;i<n+1;i++){
c[i][0]=c[i][i]=1;
for(int j=1;j<i;j++) c[i][j]=(c[i-1][j-1]+c[i-1][j])%mod;
}
int ans=0,t=(n+k)>>1;
for(int i=t;i<n+1;i++){
int sum=(ll)f[n][i]*c[i][t]%mod;
if((i-t)&1) (ans+=-sum+mod)%=mod;
else (ans+=sum)%=mod;
}
cout<<ans<<'\n';
return 0;
}
祝大家学习愉快!