Luogu 4491 [HAOI2018]染色
BZOJ 5306
考虑计算恰好出现$s$次的颜色有$k$种的方案数。
首先可以设$lim = min(m, \left \lfloor \frac{n}{s} \right \rfloor)$,我们在计算的时候只要算到这个$lim$就可以了。
设$f(k)$表示出现$s$次的颜色至少有$k$种的方案数,则
$$f(k) = \binom{m}{k}\binom{n}{ks}\frac{(ks)!}{(s!)^k}(m - k)^{n - ks}$$
就是先选出$k$个颜色和$ks$个格子放这些颜色,这样子总的方案数是全排列除以限排列,剩下的颜色随便放。
处理一下组合数,整个$f$可以在$O(nlogn)$时间内算出来。
设$g(k)$表示出现$s$次的颜色刚好有$k$种的方案数,考虑到对于$\forall i < j$,$g(j)$在$f(i)$中被计算了$\binom{j}{i}$次,所以有
$$f(k) = \sum_{i = k}^{lim}\binom{i}{k}g(i)$$
直接二项式反演回来,
$$g(k) = \sum_{i = k}^{lim}(-1)^{i - k}\binom{i}{k}f(i)$$
拆开组合数,
$$g(k) = \sum_{i = k}^{lim}(-1)^{i - k}\frac{i!}{k!(i - k)!}f(i) = \frac{1}{k!}\sum_{i = k}^{lim}\frac{(-1)^{i - k}}{(i - k)!}(f(i) * (i!))$$
设$A(i) = f(i) * (i!)$,$B(i) = \frac{(-1)^{i}}{i!}$
$$g(k)* (k!) = \sum_{i = k}^{lim}A(i)B(i - k)$$
咕,并不是卷积。
把$B$翻转,再设$B'(i) = B(lim - i)$
$$g(k)* (k!) = \sum_{i = k}^{lim}A(i)B'(lim + k - i)$$
注意到$A*B'$的第$lim + k$项的系数是$\sum_{i = 0}^{lim + k}A(i)B'(lim + k - i)$,但是需要满足
$$ 0 \leq i \leq lim$$
$$ 0 \leq lim + k - i \leq lim $$
$$k \leq i \leq lim$$
刚好满足。
时间复杂度$O(n + mlogm)$。
Code:
#include <cstdio> #include <cstring> #include <algorithm> #include <vector> using namespace std; typedef long long ll; typedef vector <ll> poly; const int N = 1e7 + 5; const int M = 1e5 + 5; int n, m, s; ll w[M], fac[N], ifac[N], f[M], g[M], sum[M]; inline void deb(poly c) { for (int i = 0; i < (int)c.size(); i++) printf("%lld%c", c[i], " \n"[i == (int)c.size() - 1]); } namespace Poly { const int L = 1 << 18; const ll P = 1004535809LL; int lim, pos[L]; template <typename T> inline void inc(T &x, T y) { x += y; if (x >= P) x -= P; } template <typename T> inline void sub(T &x, T y) { x -= y; if (x < 0) x += P; } inline void reverse(poly &c) { for (int i = 0, j = (int)c.size() - 1; i < j; i++, j--) swap(c[i], c[j]); } inline ll fpow(ll x, ll y) { ll res = 1; for (; y > 0; y >>= 1) { if (y & 1) res = res * x % P; x = x * x % P; } return res; } inline void prework(int len) { int l = 0; for (lim = 1; lim < len; lim <<= 1, ++l); for (int i = 0; i < lim; i++) pos[i] = (pos[i >> 1] >> 1) | ((i & 1) << (l - 1)); } inline void ntt(poly &c, int opt) { c.resize(lim, 0); for (int i = 0; i < lim; i++) if (i < pos[i]) swap(c[i], c[pos[i]]); for (int i = 1; i < lim; i <<= 1) { ll wn = fpow(3, (P - 1) / (i << 1)); if (opt == -1) wn = fpow(wn, P - 2); for (int len = i << 1, j = 0; j < lim; j += len) { ll w = 1; for (int k = 0; k < i; k++, w = w * wn % P) { ll x = c[j + k], y = c[j + k + i] * w % P; c[j + k] = (x + y) % P, c[j + k + i] = (x - y + P) % P; } } } if (opt == -1) { ll inv = fpow(lim, P - 2); for (int i = 0; i < lim; i++) c[i] = c[i] * inv % P; } } inline poly mul(const poly x, const poly y) { poly u = x, v = y, res; prework(x.size() + y.size() - 1); ntt(u, 1), ntt(v, 1); for (int i = 0; i < lim; i++) res.push_back(u[i] * v[i] % P); ntt(res, -1); res.resize(x.size() + y.size() - 1); return res; } } using Poly :: P; using Poly :: fpow; using Poly :: mul; using Poly :: inc; using Poly :: reverse; template <typename T> inline void read(T &X) { X = 0; char ch = 0; T op = 1; for (; ch > '9'|| ch < '0'; ch = getchar()) if (ch == '-') op = -1; for (; ch >= '0' && ch <= '9'; ch = getchar()) X = (X << 3) + (X << 1) + ch - 48; X *= op; } inline void prework(int len) { fac[0] = 1; for (int i = 1; i <= len; i++) fac[i] = fac[i - 1] * i % P; ifac[len] = fpow(fac[len], P - 2); for (int i = len - 1; i >= 0; i--) ifac[i] = ifac[i + 1] * (i + 1) % P; } inline ll getC(int x, int y) { return fac[x] * ifac[y] % P * ifac[x - y] % P; } int main() { read(n), read(m), read(s); for (int i = 0; i <= m; i++) read(w[i]); int rep = min(m, n / s); prework(max(n, m)); for (int i = 0; i <= rep; i++) f[i] = getC(m, i) * getC(n, i * s) % P * fac[i * s] % P * fpow(ifac[s], i) % P * fpow(m - i, n - i * s) % P; /* for (int i = 0; i <= rep; i++) printf("%lld%c", f[i], " \n"[i == rep]); */ poly a, b; a.resize(rep + 1, 0), b.resize(rep + 1, 0); for (int i = 0; i < rep + 1; i++) { a[i] = f[i] * fac[i] % P; b[i] = ifac[i]; if (i & 1) b[i] = (P - b[i]) % P; } reverse(b); // deb(a), deb(b); poly c = mul(a, b); // deb(c); /* for (int i = rep; i >= 0; i--) { sum[i] = sum[i + 1]; inc(sum[i], c[i]); } for (int i = 0; i <= rep; i++) printf("%lld%c", sum[i], " \n"[i == rep]); */ ll ans = 0; for (int i = 0; i <= rep; i++) { g[i] = ifac[i] * c[rep + i] % P; inc(ans, w[i] * g[i] % P); } printf("%lld\n", ans); return 0; }