【题解】P4491 [HAOI2018]染色
思路
NTT 优化二项式反演。
首先考虑到求 “正好有 \(k\) 种颜色出现 \(S\) 次” 的方案数,所以可以考虑转化成求 “至少有 \(k\) 种颜色出现 \(S\) 次” 的方案数。
形式化地,令 \(F[i]\) 为前者,\(G[i]\) 为后者。显然有 \(G[i] = \sum\limits_{k = i}^m {k \choose i} F[k]\),通过二项式反演可以得到 \(F[i] = \sum\limits_{k = i}^m (-1)^{k - i} {k \choose i} G[k]\)
\(G\) 可以直接算:首先钦定 \(i\) 种出现 \(S\) 次的颜色,然后令这 \(Si\) 个元素和剩下的 \(n - Si\) 个元素进行可重排列,同时对剩下的 \(n - Si\) 个元素任意染色。
所以得到 \(G[i] = {m \choose i} \cdot \frac{n!}{(S!)^i \cdot (n - Si)!} \cdot (n - Si)^{m - i}\).
现在的问题是反演的复杂度是 \(O(n^2)\),考虑优化。
把组合数拆开得到 \(F[i] = \sum\limits_{k = i}^m (-1)^{k - i} \frac{k!}{i! \cdot (k - i)!} G[k]\).
整理一下就是 \(F[i] \cdot (i!) = \sum\limits_{k = i}^m \frac{(-1)^{k - i}}{(k - i)!} \cdot (k!) \cdot G[k]\),是差卷积的形式。
令 \(A[i] = \frac{(-1)^i}{(i!)} x^i, B[i] = (i!) \cdot G[i] x^i\),那么 \(F\) 是 \(A\) 和 \(B\) 的差卷积。
计算差卷积可以反转 \(A\) 再用 NTT 计算,最后答案也反转过来就行。
时间复杂度 \(O(n \log n)\)
代码
#include <cstdio>
#include <algorithm>
using namespace std;
typedef long long ll;
#define swap(x, y) (x ^= y ^= x ^= y)
const int maxn = 1e7 + 5;
const int ntt_sz = 3e5 + 5;
const int mod = 1004535809;
const int g = 3;
int n, m, s;
int rev[ntt_sz];
ll fac[maxn], invf[maxn];
ll F[ntt_sz], G[ntt_sz], wp[ntt_sz];
ll qpow(ll base, ll power, ll mod)
{
ll res = 1;
while (power)
{
// printf("debug %lld\n", power);
if (power & 1) res = res * base % mod;
base = base * base % mod;
power >>= 1;
}
return res;
}
ll C(int n, int m) { return (n < m ? 0ll : fac[n] * invf[m] % mod * invf[n - m] % mod); }
void calc_rev(int k) { for (int i = 1; i < k; i++) rev[i] = (rev[i >> 1] >> 1 | (i & 1 ? k >> 1 : 0)); }
void NTT(ll *A, int n)
{
calc_rev(n);
for (int i = 1; i < n; i++)
if (rev[i] > i) swap(A[i], A[rev[i]]);
for (int len = 2, m = 1; len <= n; m = len, len <<= 1)
{
ll wn = qpow(g, (mod - 1) / len, mod);
wp[0] = 1;
for (int i = 1; i <= len; i++) wp[i] = wp[i - 1] * wn % mod;
for (int l = 0, r = len - 1; r <= n; l += len, r += len)
{
int w = 0;
for (int p = l; p < l + m; p++, w++)
{
ll x = A[p], y = wp[w] * A[p + m] % mod;
A[p] = (x + y) % mod, A[p + m] = (x - y + mod) % mod;
}
}
}
}
void INTT(ll *A, int n)
{
NTT(A, n);
reverse(A + 1, A + n);
int inv = qpow(n, mod - 2, mod);
for (int i = 0; i < n; i++) A[i] = 1ll * A[i] * inv % mod;
}
int main()
{
scanf("%d%d%d", &n, &m, &s);
int lim = min(m, n / s);
fac[0] = invf[0] = 1;
for (int i = 1; i <= max(n, m); i++) fac[i] = fac[i - 1] * i % mod;
invf[max(n, m)] = qpow(fac[max(n, m)], mod - 2, mod);
for (int i = max(n, m) - 1; i; i--) invf[i] = invf[i + 1] * (i + 1) % mod;
for (int i = 0; i <= lim; i++)
{
// printf("debug %lld\n", s - n * i);
F[i] = C(m, i) * fac[n] % mod * qpow(invf[s], i, mod) % mod * invf[n - s * i] % mod * qpow(m - i, n - s * i, mod) % mod;
F[i] = F[i] * fac[i] % mod, G[i] = (i & 1) ? mod - invf[i] : invf[i];
}
reverse(F, F + lim + 1);
int k = 1;
while (k < (lim * 2 + 2)) k <<= 1;
NTT(F, k), NTT(G, k);
for (int i = 0; i < k; i++) F[i] = F[i] * G[i] % mod;
INTT(F, k);
reverse(F, F + lim + 1);
ll ans = 0;
for (int i = 0, w; i <= lim; i++)
{
scanf("%d", &w);
ans = (ans + F[i] * invf[i] % mod * w) % mod;
}
printf("%lld\n", ans);
return 0;
}