【bzoj2839】【集合计数】容斥原理+线性求阶乘逆元小技巧
(上不了p站我要死了,侵权度娘背锅)
Description
一个有N个元素的集合有2^N个不同子集(包含空集),现在要在这2^N个集合中取出若干集合(至少一个),使得
它们的交集的元素个数为K,求取法的方案数,答案模1000000007。(是质数喔~)
Input
一行两个整数N,K
Output
一行为答案。
Sample Input
3 2
Sample Output
6
HINT
【样例说明】
假设原集合为{A,B,C}
则满足条件的方案为:{AB,ABC},{AC,ABC},{BC,ABC},{AB},{AC},{BC}
【数据说明】
对于100%的数据,1≤N≤1000000;0≤K≤N;
自己的数论果然还是太差了,看来大佬的博客才知道该怎么做
当 题目所求难以直接求得,而扩大一些的范围方便求 时,就思考用容斥定理。对于这道题,交集恰好为k的不好求,但是交集元素包含k的方便求,所以就考虑用 包含k的方案数-包含k+1的方案数+包含k+2的方案数…… 来求出 恰好包含k 的方案数。于是就将大问题拆成了相似的小问题。
交集元素个数包含i的方案数是很好求的。首先要保证一定有i个元素,所以就是“从n个元素中选i个元素”,即C(n,i)。接下来就是选择一些集合包含选出来的i个元素的方案数,该问题等价于选一些集合不包含这i个元素的方案数。即 除开这i个元素,剩下的n-i个元素组成的集合有2^(n-i)个,每个集合有选与不选两种状态,所以就是2^(2^(n-i))-1种方案。因为不能全都不选(2^(n-i)个集合中已包含空集),所以要-1。
但是却发现这样的做法连样例都过不了O^O,难道说这个方法错了吗?冷静下来仔细分析样例,按照这样的算法得出:
交集包含2个的:共9种
{{1,2}},{{1,2,3}},{{1,2},{1,2,3}}
{{1,3}},{{1,3,2}},{{1,3},{1,3,2}}
{{3,2}},{{3,2,1}},{{3,2},{3,2,1}}
交集包含3个的:共1种
{{1,2,3}}
我们发现,在“交集包含2的”的情况中,集合{1,2,3}出现了3次,包含的2个元素分别是我们选出来的{1,2},{2,3},{1,3}。所以我们算重了3次。化为一般情况就是 包含i个的 方案∗选择包含的是哪k个元素(i个元素中选k个元素),即为 C(i,k)∗C(n,i)∗(2^(2^(n-i)-1)。
加上容斥原理,令f[i]=2^(2^(n-i)-1,最终的公式为:
然后有一些缩短时间的小技巧。虽然题目的数据o(nlogn)是能过的,但是相比起来就太慢了一些,n再加一个0就完了。
而logn的复杂度主要是出在 阶乘求逆元 和 求f[i] 上。仔细观察性质,发现f[i]虽然是一个指数套指数的形式,但由于底数是2,所以可以相乘来递推(这样也就不用担心指数的快速幂的模数要取phi(mod)了)。而 阶乘的逆元是可以线性求解的(我以前一直都不知道qwq,又长见识了)。
因为 ( (n-1)! )^-1 = (n!)^-1 * n ,所以先求出阶乘,o(logn)求出 n! 的逆元,再倒着推回去。在本题总时间减少了近一半。
AC代码
#include<cstdio>
#include<cstring>
#include<iostream>
#include<algorithm>
#define ll long long
using namespace std;
#ifdef WIN32
#define RIN "%I64d"
#else
#define RIN "%lld"
#endif
template <typename T>inline void read(T &res){
T k=1,x=0;char ch=0;
while(ch<'0'||ch>'9'){if(ch=='-')k=-1;ch=getchar();}
while(ch>='0'&&ch<='9'){x=(x<<3)+(x<<1)+ch-'0';ch=getchar();}
res=k*x;
}
const int N=1000000+5;
const ll MOD=1e9+7;
ll n,k,jiec[N],niy[N];
ll f[N];
void exgcd(ll a,ll b,ll &x,ll &y){
if(b==0){
x=1,y=0;
return;
}
ll x0,y0;
exgcd(b,a%b,x0,y0);
x=y0;
y=x0-(a/b)*y0;
}
ll inverse(ll a){
ll x,y;
exgcd(a,MOD,x,y);
return (x%MOD+MOD)%MOD;
}
void init(){
jiec[0]=niy[0]=1;
/*for(int i=1;i<=n;i++){
jiec[i]=jiec[i-1]*i%MOD;
niy[i]=inverse(jiec[i]);
}*/
for(int i=1;i<=n;i++) jiec[i]=jiec[i-1]*i%MOD;
niy[n]=inverse(jiec[n]);
for(int i=n-1;i>=1;i--) niy[i]=niy[i+1]*(i+1)%MOD;
}
ll power(ll a,ll b,ll mod){
ll rt=1;
for(int i=b;i;i>>=1,a=(a*a)%mod)
if(i&1) rt=(rt*a)%mod;
return rt;
}
ll F(ll x){
/*ll tmp=power(2,x,MOD-1);
return (power(2,tmp,MOD)-1+MOD)%MOD;*/
return (f[x]-1+MOD)%MOD;
}
ll C(ll a,ll b){
if(b>a) return 0;
return jiec[a]*niy[b]%MOD*niy[a-b]%MOD;
}
int main(){
read(n),read(k);
init();
f[0]=2;
for(int i=1;i<=n-k;i++) f[i]=f[i-1]*f[i-1]%MOD;
ll ans=0;
for(ll i=k;i<=n;i++)
ans=(ans+C(n,i)*C(i,k)%MOD*F(n-i)%MOD*((i-k&1)?-1:1)%MOD+MOD)%MOD;
cout<<ans<<endl;
return 0;
}