容斥原理
容斥原理
一.容斥原理:
设\(S_1,S_2,S_3...,S_n\)为有限集合,\(\mid S_i\mid\)表示集合大小,则:
\(\mid\bigcup\limits_{i=1}^{n}S_i\mid=\sum\limits_{i=1}^{n}\mid S_i\mid-\sum\limits_{1\leq i<j\leq n}\mid S_i\cap S_j\mid+\sum\limits_{1\leq i<j<k\leq n}\mid S_i\cap S_j\cap S_k\mid+...+(-1)^{n+1}\mid S_1\cap S_2\cap...\cap S_k\mid\)
(终于把公式码完了)
似乎很简单.
二.应用
多重集的组合数.
有多重集\(S=\{ n_1a_1,n_2a_2,\dots,n_ka_k\}\)
求从中取出\(r\)个元素组成多重集(不考虑顺序)的方案数\(N\).
我们令\(n=\sum n_i\),分情况讨论:
\((1)\qquad r>n,N=0\)
\((2)\qquad r=n,N=1\)
\((3) \qquad r<n, \forall i\quad n_i>r,N=C_{r+k-1}^{r}\)
\((4)\qquad r<n, \exists i \quad n_i>r,N=C_{r+k-1}^{r}-\sum\limits_{1\le i\le k} C^{k-1}_{k+r-n_i-2}+\dots\)
1,2比较显然,主要是看3,4.
3:
\(\forall i\quad n_i>r\),说明不要考虑\(n_i \)的限制,直接考虑隔板法.
4:
要考虑$n_i $的限制,那么用总情况减去不合法情况.
怎样是不合法的呢?
至少有一种选的数量超过限制.
设\(S_i\)表示至少含有\(n_i+1\)个\(a_i\),且\(\mid S_i\mid=r\)
那么对于同一个\(i\),不同的\(S_i\)有\(C_{k+r-n_i-2}^{k-1}\)种.(仍然用隔板法考虑)
所以先减去这一部分,也就是式子中第一个\(\Sigma\).
但是有些选法会有两种超过限制,所以会减重复,所以要加上.
以此类推,就是容斥的式子.
三.计算
观察上面的式子,把每一个\(\Sigma\)拆开.
那么每一项就是各自独立,只跟我们选取了哪几种物品进行容斥有关.
所以我们可以枚举子集,分别计算每一项的值,复杂度\(O(2^nn)\).
还有一个要注意的地方是算组合数那里.虽然m有\(10^{12}\),但是n只有20,所以直接根据定义暴算阶乘就可以了.
还有最好预处理逆元.
真的不要Lucas,快速乘多好用
代码:
#include<bits/stdc++.h>
#define gc getchar
#define R register int
#define LL long long
#define IL inline
using namespace std;
const LL mod=1e9+7;
IL LL rd()
{
LL ans = 0,flag = 1;
char ch = gc();
while((ch>'9'||ch<'0')&&ch!='-') ch=gc();
if(ch == '-') flag=-1,ch=gc();
while(ch>='0'&&ch<='9')
ans=(ans<<3ll)+(ans<<1ll)+ch-48,ch=gc();
return 1ll*flag*ans;
}
IL LL qmul(LL x,LL y)
{
return (x*y-(LL)((long double)x/mod*y)*mod+mod)%mod;
}
const int N=22;
LL inv[N],n,s,a[N],ans;
IL LL qpow(LL x,LL a)
{
LL ans=1;
while(a)
{
if(a&1)
ans=ans*x%mod;
a>>=1,x=x*x%mod;
}
return ans;
}
void Pretreat()
{
for(R i=1;i<=21;i++)
inv[i]=qpow(i,mod-2);
}
IL LL C(LL y,LL x) //C(y,x)
{
if(x<0||y<0||y<x) return 0;
LL ans=1;
for(LL i=y;i>y-x;i--)
ans=qmul(ans,i);
for(LL i=1;i<=x;i++)
ans=qmul(ans,inv[i]);
return ans;
}
int main()
{
n=rd(),s=rd();
for(R i=1;i<=n;i++)
a[i]=rd();
Pretreat();
ans=C(n+s-1,n-1);
for(R x=1;x<(1<<n);x++)
{
int p=0;
LL tmp=0;
for(R i=0;i<n;i++)
if((x>>i)&1)
p++,tmp+=a[i+1];
if(p&1)
ans=(ans-C(n+s-tmp-p-1,n-1))%mod;
else
ans=(ans+C(n+s-tmp-p-1,n-1))%mod;
}
cout<<(ans+mod)%mod<<endl;
return 0;
}
(持续补坑...)