●BZOJ 2839 集合计数

题链:

http://www.lydsy.com/JudgeOnline/problem.php?id=2839

题解:

容斥原理

真的是神题!!!

定义 f[k] 表示交集大小至少为 k时的方案数
怎么求出这个数组呢?
考虑先确定 k个元素(有C(N,k)种方法),那么还剩下 N-k个元素,
这剩下的 N-k个元素可以得到 2^(N-k)个集合,
然后每个集合可以选或不选,(但不能一个都不选),可以得到 2^(2^(N-k))-1 种选法,
每种选法里面的每个集合都加上那以及固定的 k个元素,
可以发现这所有的选法的交集大小都至少为 k。
所以 f[k]=C(N,k)*(22^(N-k)-1)

但是 f[k]还包含了交集为 k+1,k+2,k+3的方法,要怎么减去才能得到交集恰好为 k的方案数呢?
先看看这样一个结论:
不难发现 每种交集恰好为 k+1的方案都在 f[k]中被计算了 C(k+1,k)次。
怎么理解呢?
每种交集恰好为 k+1的方案(记这种方案为 A)的那 k+1个交集元素,
在计算 f[k] 时都可以从中选出 k个来固定,然后得到方案 A,即 A 方案在f[k]中被重复得到了 C(k+1,k)次。
所以要把重复的减去,容斥系数如下(当前要计算交集大小恰好为 k 的方案数):
f[k]        :1
f[k+1]    :-C(k+1,k)
f[k+2]    :+C(k+2,k)
诶,这个 f[k+2]的系数是怎么得到的呢?
首先 每种交集为 k+2 的方案在 f[k]中重复得了 C(k+2,k)次,所以 -C(k+2,k)
但是因为 f[k+1]的系数为 -C(k+1,k),
虽然我们只想减去 f[k]里重复了C(k+1,k)次的交集大小为 k+1 的方案数,
但是无奈再看看定义,f[k+1]里面还包含了交集大小为 k+2,k+3...的方案
所以在给 f[k+1]乘上系数 -C(k+1,k)时,也把 f[k+1]里面的每种交集大小为 k+2的方案减去了 C(k+1,k)次,
同时每种交集大小为 k+2的方案又在 f[k+1]里面的出现了 C(k+2,k+1)次
所以此时要加上因为 f[k+1]*-C(k+1,k)而减去了的交集大小为k+2的方案数,
+C(k+1,k)*C(k+2,k+1)
所以把两个结合起来: -C(k+2,k)+C(k+1,k)*C(k+2,k+1) ,化简即可得到 +C(k+2,k)
类似的可以得到 f[k+3],f[k+4]...的系数 :
f[k]        :1
f[k+1]    :-C(k+1,k)
f[k+2]    :-C(k+2,k)+C(k+1,k)*C(k+2,k+1) = +C(k+2,k)
f[k+3]    :-C(k+3,k)+C(k+1,k)*C(k+3,k+1)-C(k+2,k)*C(k+3,k+2) = -C(k+3,k)
f[k+4]    :-C(k+4,k)+C(k+1,k)*C(k+4,k+1)-C(k+2,k)*C(k+4,k+2)+C(k+3,k)*C(k+3,k+4) = +C(k+4,k)
......
总的式子为   n
                   ∑ (-1)^(i-k)*C(i,k)*f[i]
                  i=k

然后逮着式子计算就好了。
另外要先线性预处理出阶乘和阶乘逆元以及f数组,便于使用。

代码:

#include<cstdio>
#include<cstring>
#include<iostream>
#define _ %mod
#define MAXN 1005000
#define filein(x) freopen(#x".in","r",stdin);
#define fileout(x) freopen(#x".out","w",stdout);
using namespace std;
const int mod=1000000007;
int w[MAXN],f[MAXN],fac[MAXN],inv[MAXN];
int N,K,ANS;
int C(int n,int m){//n中选 m个 
	return ((1ll*fac[n]*inv[m])_*inv[n-m])_;
}
int pow(int a,int b){
	int now=1;
	while(b){
		if(b&1) now=(1ll*now*a)_;
		a=(1ll*a*a)_;
		b>>=1;
	}
	return now;
}
void pre(int n){
	w[0]=2;	fac[0]=1; 
	for(int i=1;i<=n;i++)
		fac[i]=(1ll*fac[i-1]*i)_;
	inv[n]=pow(fac[n],mod-2);
	for(int i=n-1;i>=0;i--)
		inv[i]=(1ll*inv[i+1]*(i+1))_;
	for(int i=1;i<=n;i++)
		w[i]=(1ll*w[i-1]*w[i-1])_;
	for(int i=0;i<=n;i++)
		f[i]=(1ll*(w[n-i]-1)*C(n,i))_;
}
int main()
{
	scanf("%d%d",&N,&K);
	pre(N);
	for(int i=K;i<=N;i++)
		ANS=(1ll*C(i,K)*f[i]*((i-K)&1?-1:1)+ANS)_;
	ANS=(ANS+mod)_;
	printf("%d",ANS);
	return 0;
}

posted @ 2017-12-12 19:21  *ZJ  阅读(227)  评论(0编辑  收藏  举报