「学习笔记」二项式反演
概念
二项式反演其实就是利用容斥的思想处理一些通过求“至少或至多”来解决“恰好”的问题。
形式
其中,形式三比较常用,组合意义为 \(f(n)\) 表示“至少选 \(n\) 个”,\(g(n)\) 表示“恰好选 \(n\) 个”。
例题
Luogu P4859 已经没有什么好害怕的了
Description
给定两个长为 \(n\) 的序列 \(a,b\),它们两两配对,求配对后 \(a>b\) 的组数比 \(b>a\) 的组数恰好多 \(k\) 组的方案数。
\(1\le n \le 2000,0\le k\le n\)
Solution
题目要求“恰好多 \(k\) 组”,共有 \(n\) 组,所以相当于 \(a>b\) 恰好 \(\dfrac {n+k}2\) 组。
设 \(dp_{i,j}\) 表示前 \(i\) 个数中,有 \(j\) 组 \(a>b\) 的方案数,转移方程为
其中,\(cnt_i\) 表示 \(b\) 中比 \(a_i\) 小的数的个数,这个可以将 \(a,b\) 排序后双指针扫
接下来,记 \(f_i=dp_{n,i}\times (n-i)!\),也就是至少 \(i\) 组的方案数
然后根据二项式反演就可以得到恰好 \(k\) 组的方案数 \(g_k\)
Code
int n, k, a[N], b[N], cnt[N];
ll fac[N], dp[N][N], f[N], g[N];
ll qpow(ll a, int b)
{
ll res = 1;
while(b)
{
if(b & 1) res = res * a % mod;
a = a * a % mod, b >>= 1;
}
return res;
}
ll add(ll x) {return x < mod ? x : x - mod;}
ll inv(ll x) {return qpow(x, mod - 2);}
ll C(int n, int m) {return n < m ? 0 : fac[n] * inv(fac[m]) % mod * inv(fac[n - m]) % mod;}
int main()
{
read(n), read(k);
if((n + k) & 1)
{
puts("0");
return 0;
}
k = (n + k) >> 1;
for(int i = 1; i <= n; i++) read(a[i]);
for(int i = 1; i <= n; i++) read(b[i]);
sort(a + 1, a + 1 + n);
sort(b + 1, b + 1 + n);
fac[0] = 1;
for(int i = 1, j = 1; i <= n; i++)
{
while(j <= n && a[i] > b[j]) j++;
cnt[i] = j - 1;
fac[i] = fac[i - 1] * i % mod;
}
dp[0][0] = 1;
for(int i = 1; i <= n; i++)
for(int j = 0; j <= i; j++)
dp[i][j] = add(dp[i - 1][j] + (!j ? 0 : dp[i - 1][j - 1] * (cnt[i] - j + 1) % mod));
for(int i = 0; i <= n; i++) f[i] = dp[n][i] * fac[n - i] % mod;
for(int i = 1; i <= n; i++)
for(int j = k; j <= n; j++)
g[i] = add(g[i] + add((((j - k) & 1) ? -1 : 1) * f[j] * C(j, k) % mod + mod));
write(g[k]), pc('\n');
return 0;
}
// A.S.
Luogu P4491 [HAOI2018]染色
Description
有一个长为 \(n\) 的序列,每个位置都可以是 \([1,m]\) 中的某一个数,若这 \(n\) 个数中恰好出现了 \(s\) 次的数有 \(k\) 个,那么会得到 \(w_k\) 的贡献。
求对于所有可能的情况,能获得的权值的和对 \(1004535809\) 取模的结果是多少。
\(1\le n\le 10^7,1\le m \le 10^5,0\le s\le 150,0\le w_i\le 1004535809\)
Solution
显然数的个数不会超过 \(cnt=\min(m,n/s)\)
依然是恰好出现 \(s\) 次,考虑计算有 \(i\) 个数至少出现 \(s\) 次的方案数 \(f_i\)
钦定 \(i\) 个数出现了 \(s\) 次,剩下的 \(n-is\) 个位置在 \(m-i\) 个数中随便选
然后进行二项式反演,设 \(g_k\) 表示有 \(k\) 个数恰好出现 \(s\) 次
到这里就能看出来卷积的形式了
设
那么 \(g_i=\dfrac{(F*G)(i)}{i!}\)
NTT 计算卷积即可
Code
#include <bits/stdc++.h>
#define ll long long
#define db double
#define gc getchar
#define pc putchar
using namespace std;
namespace IO
{
template <typename T>
void read(T &x)
{
x = 0; bool f = 0; char c = gc();
while(!isdigit(c)) f |= c == '-', c = gc();
while(isdigit(c)) x = x * 10 + c - '0', c = gc();
if(f) x = -x;
}
template <typename T>
void write(T x)
{
if(x < 0) pc('-'), x = -x;
if(x > 9) write(x / 10);
pc('0' + x % 10);
}
}
using namespace IO;
const int MAXN = 1e7 + 5;
const int N = 1e5 + 5;
const int mod = 1004535809;
const int G = 3;
const int Gi = 334845270;
ll add(ll x) {return x < mod ? x : x - mod;}
ll sub(ll x) {return x < 0 ? x + mod : x;}
ll qpow(ll a, int b)
{
ll res = 1;
while(b)
{
if(b & 1) res = res * a % mod;
a = a * a % mod, b >>= 1;
}
return res;
}
ll inv(ll x) {return qpow(x, mod - 2);}
ll fac[MAXN], f[N << 2], g[N << 2];
ll C(int n, int m)
{
return n < m ? 0 : fac[n] * inv(fac[m]) % mod * inv(fac[n - m]) % mod;
}
int rev[N << 2];
int calclim(int n)
{
int lim = 1;
while(lim < n) lim <<= 1;
return lim;
}
void calcrev(int lim)
{
for(int i = 0; i < lim; i++)
rev[i] = (rev[i >> 1] >> 1) | ((i & 1) * (lim >> 1));
}
void NTT(ll *a, int lim, int type)
{
for(int i = 0; i < lim; i++)
if(i < rev[i]) swap(a[i], a[rev[i]]);
for(int mid = 1; mid < lim; mid <<= 1)
{
ll wn = qpow(type == 1 ? G : Gi, (mod - 1) / (mid << 1));
for(int i = 0; i < lim; i += (mid << 1))
{
ll w = 1;
for(int j = 0; j < mid; j++, w = w * wn % mod)
{
ll x = a[i + j], y = w * a[i + mid + j] % mod;
a[i + j] = add(x + y);
a[i + mid + j] = sub(x - y);
}
}
}
if(type == -1)
{
ll limi = qpow(lim, mod - 2);
for(int i = 0; i < lim; i++) a[i] = a[i] * limi % mod;
}
return;
}
int main()
{
int n, m, s;
read(n), read(m), read(s);
int cnt = min(m, n / s) + 1;
fac[0] = 1;
for(int i = 1; i < MAXN; i++)
fac[i] = fac[i - 1] * i % mod;
for(int i = 0; i < cnt; i++)
{
f[i] = fac[i] * C(m, i) % mod * fac[n] % mod * inv(qpow(fac[s], i)) % mod * inv(fac[n - s * i]) % mod * qpow(m - i, n - s * i) % mod;
g[i] = (i & 1) ? mod - inv(fac[i]) : inv(fac[i]);
}
reverse(f, f + cnt);
int lim = calclim(cnt << 1);
calcrev(lim);
NTT(f, lim, 1), NTT(g, lim, 1);
for(int i = 0; i < lim; i++) f[i] = f[i] * g[i] % mod;
NTT(f, lim, -1);
reverse(f, f + cnt);
ll ans = 0;
for(int i = 0, w; i < cnt; i++)
read(w), ans = add(ans + inv(fac[i]) * f[i] % mod * w % mod);
write(ans), pc('\n');
return 0;
}
// A.S.