二项式反演
简述
二项式反演常用于解决整体方案较容易解决,而加上限制之后的子问题较难解决的问题。从名字上就可以判断,它和组合数相关。
举个例子,恰用\(k\)种颜色填满\(n\)个格子,要求相邻格子异色。一眼看上去很让人头疼啊,但是如果没有一定要用完\(k\)种颜色的限制的话就简单多了。
设\(g_i\)表示恰用i种颜色的方案数,则有\(k(k-1)^{n-1}=\sum_{i=0}^k \binom{k}{i} g_i\)。设\(f_k=k(k-1)^{n-1}\)。现在我们希望从\(f\)倒推到\(g\),二项式反演就可以做到这一点。
先直接给出结论,\(g_k=\sum_{i=0}^k (-1)^{k-i}\binom{k}{i} f_i\)。证明下面会有,同时也会给出二项式反演的其它形式。
初步证明
在求方案的时候,我们常用容斥来逆推一个问题的答案。事实上,二项式反演就和容斥的原理相关,容斥本身就是一种反演(子集反演)。下用容斥原理来对二项式反演进行证明。
设共有\(n\)个集合,\(A_i\)表示第\(i\)个集合,则所有集合的并集元素个数可表示成以下形式:
容斥定理证明
设\(A_i^c\)表示\(A_i\)的补集,\(S\)表示全集,则:
也即:不被任一集合所包含的所有元素为全集中的元素减去至少被一个集合包含的元素。
又由于补集的补集是其本身,所以\(|A_1\cap A_2\cap\cdots\cap A_n|=|S|-\sum_{i=1}^n (-1)^{i-1}\sum|A_1^c\cap A_2^c\cap\cdots\cap A_i^c|\)同样成立。
若多个交集的大小只和集合个数有关,设\(f_n\)表示\(n\)个补集的交集大小,\(g_n\)表示\(n\)个原集的交集大小。特别的,空集的补集为\(S\),设\(f_0=g_0=|S|\)。则由上述两个式子可得到:
显然两个式子可互相推导,则二项式反演的初始形式得证:
变式
(1)
证明
设\(h_i=(-1)^i g_i\),由初始形式:
证毕。
(1')
证明
将\(g_i=\sum_{j=m}^i (-1)^{i-j}\binom{i}{j}f_j\)代入\(\sum_{i=m}^n\binom{n}{i}g_i--(*)\)得:
调换求和顺序,这一步的目的是为了将\(f_j\)提出来。由于对于\(\forall j\in[m,n]\),\(f_j\)会被\(i\in[j,n]\)计算一次,且乘上的系数为\((-1)^{i-j}\binom{n}{i}\binom{i}{j}\),可得:
再进行进一步的转化,现在感觉前边没有组合数有点不对称,想分一个组合数给\(j\)那边使得形式比较统一。\(\binom{n}{i}\binom{i}{j}\)的意义是从\(n\)中选出\(i\)个,再从\(i\)中选出\(j\)个。不妨转化为先在\(n\)个中选出\(j\)个,代表最终被选的;再从剩下的\(n-j\)个中选出\(i-j\)个,代表陪跑的倒霉蛋。或直接代数可验证,\(\binom{n}{i}\binom{i}{j}=\binom{n}{j}\binom{n-j}{i-j}\)。则:
设\(t=i-j\),有:
全程都是由等号连接,(1`)的充分必要性当然就得到了证明。它的适用性更广,因为没有用上\(i\)从0开始的性质。
(2)
证明
将\(g_i=\sum_{j=i}^m (-1)^{j-i}\binom{j}{i}f_i\)代入\(\sum_{i=n}^{m}\binom{i}{n}g_i--(**)\),得:
证毕。
(2')
证明
设\(h_i=x^i g_i,H_i=x^i f_i\),由(2)得:
证毕。
相关练习
[CQOI2015] 多项式
其实是高精度的毒瘤题……不想调高精度的写出式子就可以跑了,或者拿个模板用一下
正确式子:
代码
[click]
#include <cstdio>
#include <cctype>
typedef long long ll;
const int p=1000000007;
const int maxn=1000000+10;
int fac[maxn],inv[maxn];
int n,k;
int read()
{
int res=0;
char ch=getchar();
while(!isdigit(ch))
ch=getchar();
while(isdigit(ch))
res=res*10+ch-'0',ch=getchar();
return res;
}
int mod(ll x)
{
if (x<0)
x%=p,x+=p;
else if (x>=p)
x%=p;
return x;
}
int power(int a,int n)
{
int res=1;
while(n)
{
if (n&1)
res=mod((ll)res*a);
a=mod((ll)a*a);
n>>=1;
}
return res;
}
void prework()
{
fac[0]=1;
for (int i=1;i<=n;i++)
fac[i]=mod((ll)fac[i-1]*i);
inv[n]=power(fac[n], p-2);
for (int i=n-1;i>=1;i--)
inv[i]=mod((ll)inv[i+1]*(i+1));
inv[0]=1;
}
int C(int n,int m) {return mod((ll)mod((ll)fac[n]*inv[m])*inv[n-m]);}
int main()
{
n=read(),k=read();
prework();
int ans=0,mul=2;
for (int i=n;i>=k;i--)
{
if ((i-k)&1)
ans-=mod((ll)mod((ll)C(i, k)*C(n, i))*(mul-1));
else
ans+=mod((ll)mod((ll)C(i, k)*C(n, i))*(mul-1));
ans=mod(ans);
mul=mod((ll)mul*mul);
}
printf("%d\n",ans);
return 0;
}
[BZOJ2839] 集合计数
题解
[click]
交集元素个数为k,可考虑先确定交集的元素,再去从剩下的\(2^{2^{n-k}}-1\)个子集中选出不相交的集合,这样就可以确定选出的集合交集元素为k的所有方案。但是,选出的子集的数量不确定,怎么选出不相交的子集也是一个很大的问题。这个时候就考虑用二项式反演从整体推到局部来简化问题了。
设\(f_i\)表示取出的集合交集元素恰好为i个的方案数,\(g_i=\binom{n}{i}(2^{2^{n-i}}-1)\)表示交集元素至少存在某i个的选择方案数之和。需要注意的是,\(g_i\)并不是交集元素至少为i个的方案总数,会有算重的情况出现,这是因为在任意选剩下的集合的时候会使交集的范围变大。从另外一种角度来说,也可以理解为,在方案进行加和时,由于选择的角度不同而导致了重复。所以应该把\(g_i\)中的\(i\)视为选出了\(i\)个代表元素用于保底,加强这个方案的限制。由于每个取出的集合交集元素恰好为i个的方案都会被\(g_j,i\le j\le n\)算\(\binom{i}{j}\)次(相当于从交集中的i个元素中选j个作为代表),则:
当熟悉了以后,反演就是一个有点套路的东西,关键在于想明白怎么不算重不算漏然后进行转化。同时也不能对二项式反演过于执着,这只是一个考虑的方向。
同时从这里可以注意到,当只确定至少k个的方案比较困难时,可以考虑钦定某k个一定存在,以它为第二层标准进行计数。由于元素间地位相等,则可以通过在恰好的方案中进行选择的方法,即乘上\(\binom{n}{k}\)并求和来推导。
代码
[click]
#include <cstdio>
#include <cctype>
typedef long long ll;
const int p=1000000007;
const int maxn=1000000+10;
int fac[maxn],inv[maxn];
int n,k;
int read()
{
int res=0;
char ch=getchar();
while(!isdigit(ch))
ch=getchar();
while(isdigit(ch))
res=res*10+ch-'0',ch=getchar();
return res;
}
int mod(ll x)
{
if (x<0)
x%=p,x+=p;
else if (x>=p)
x%=p;
return x;
}
int power(int a,int n)
{
int res=1;
while(n)
{
if (n&1)
res=mod((ll)res*a);
a=mod((ll)a*a);
n>>=1;
}
return res;
}
void prework()
{
fac[0]=1;
for (int i=1;i<=n;i++)
fac[i]=mod((ll)fac[i-1]*i);
inv[n]=power(fac[n], p-2);
for (int i=n-1;i>=1;i--)
inv[i]=mod((ll)inv[i+1]*(i+1));
inv[0]=1;
}
int C(int n,int m) {return mod((ll)mod((ll)fac[n]*inv[m])*inv[n-m]);}
int main()
{
n=read(),k=read();
prework();
int ans=0,mul=2;
for (int i=n;i>=k;i--)
{
if ((i-k)&1)
ans-=mod((ll)mod((ll)C(i, k)*C(n, i))*(mul-1));
else
ans+=mod((ll)mod((ll)C(i, k)*C(n, i))*(mul-1));
ans=mod(ans);
mul=mod((ll)mul*mul);
}
printf("%d\n",ans);
return 0;
}