Loading

【题解】P4491 [HAOI2018]染色

思路

NTT 优化二项式反演。

首先考虑到求 “正好有 \(k\) 种颜色出现 \(S\) 次” 的方案数,所以可以考虑转化成求 “至少有 \(k\) 种颜色出现 \(S\) 次” 的方案数。

形式化地,令 \(F[i]\) 为前者,\(G[i]\) 为后者。显然有 \(G[i] = \sum\limits_{k = i}^m {k \choose i} F[k]\),通过二项式反演可以得到 \(F[i] = \sum\limits_{k = i}^m (-1)^{k - i} {k \choose i} G[k]\)

\(G\) 可以直接算:首先钦定 \(i\) 种出现 \(S\) 次的颜色,然后令这 \(Si\) 个元素和剩下的 \(n - Si\) 个元素进行可重排列,同时对剩下的 \(n - Si\) 个元素任意染色。

所以得到 \(G[i] = {m \choose i} \cdot \frac{n!}{(S!)^i \cdot (n - Si)!} \cdot (n - Si)^{m - i}\).

现在的问题是反演的复杂度是 \(O(n^2)\),考虑优化。

把组合数拆开得到 \(F[i] = \sum\limits_{k = i}^m (-1)^{k - i} \frac{k!}{i! \cdot (k - i)!} G[k]\).

整理一下就是 \(F[i] \cdot (i!) = \sum\limits_{k = i}^m \frac{(-1)^{k - i}}{(k - i)!} \cdot (k!) \cdot G[k]\),是差卷积的形式。

\(A[i] = \frac{(-1)^i}{(i!)} x^i, B[i] = (i!) \cdot G[i] x^i\),那么 \(F\)\(A\)\(B\) 的差卷积。

计算差卷积可以反转 \(A\) 再用 NTT 计算,最后答案也反转过来就行。

时间复杂度 \(O(n \log n)\)

代码

#include <cstdio>
#include <algorithm>
using namespace std;

typedef long long ll;

#define swap(x, y) (x ^= y ^= x ^= y)

const int maxn = 1e7 + 5;
const int ntt_sz = 3e5 + 5;
const int mod = 1004535809;
const int g = 3;

int n, m, s;
int rev[ntt_sz];
ll fac[maxn], invf[maxn];
ll F[ntt_sz], G[ntt_sz], wp[ntt_sz];

ll qpow(ll base, ll power, ll mod)
{
    ll res = 1;
    while (power)
    {
        // printf("debug %lld\n", power);
        if (power & 1) res = res * base % mod;
        base = base * base % mod;
        power >>= 1;
    }
    return res;
}

ll C(int n, int m) { return (n < m ? 0ll : fac[n] * invf[m] % mod * invf[n - m] % mod); }

void calc_rev(int k) { for (int i = 1; i < k; i++) rev[i] = (rev[i >> 1] >> 1 | (i & 1 ? k >> 1 : 0)); }

void NTT(ll *A, int n)
{
    calc_rev(n);
    for (int i = 1; i < n; i++)
        if (rev[i] > i) swap(A[i], A[rev[i]]);
    for (int len = 2, m = 1; len <= n; m = len, len <<= 1)
    {
        ll wn = qpow(g, (mod - 1) / len, mod);
        wp[0] = 1;
        for (int i = 1; i <= len; i++) wp[i] = wp[i - 1] * wn % mod;
        for (int l = 0, r = len - 1; r <= n; l += len, r += len)
        {
            int w = 0;
            for (int p = l; p < l + m; p++, w++)
            {
                ll x = A[p], y = wp[w] * A[p + m] % mod;
                A[p] = (x + y) % mod, A[p + m] = (x - y + mod) % mod;
            }
        }
    }
}

void INTT(ll *A, int n)
{
    NTT(A, n);
    reverse(A + 1, A + n);
    int inv = qpow(n, mod - 2, mod);
    for (int i = 0; i < n; i++) A[i] = 1ll * A[i] * inv % mod;
}

int main()
{
    scanf("%d%d%d", &n, &m, &s);
    int lim = min(m, n / s);
    fac[0] = invf[0] = 1;
    for (int i = 1; i <= max(n, m); i++) fac[i] = fac[i - 1] * i % mod;
    invf[max(n, m)] = qpow(fac[max(n, m)], mod - 2, mod);
    for (int i = max(n, m) - 1; i; i--) invf[i] = invf[i + 1] * (i + 1) % mod;
    for (int i = 0; i <= lim; i++)
    {
        // printf("debug %lld\n", s - n * i);
        F[i] = C(m, i) * fac[n] % mod * qpow(invf[s], i, mod) % mod * invf[n - s * i] % mod * qpow(m - i, n - s * i, mod) % mod;
        F[i] = F[i] * fac[i] % mod, G[i] = (i & 1) ? mod - invf[i] : invf[i];
    }
    reverse(F, F + lim + 1);
    int k = 1;
    while (k < (lim * 2 + 2)) k <<= 1;
    NTT(F, k), NTT(G, k);
    for (int i = 0; i < k; i++) F[i] = F[i] * G[i] % mod;
    INTT(F, k);
    reverse(F, F + lim + 1);
    ll ans = 0;
    for (int i = 0, w; i <= lim; i++)
    {
        scanf("%d", &w);
        ans = (ans + F[i] * invf[i] % mod * w) % mod;
    }
    printf("%lld\n", ans);
    return 0;
}
posted @ 2023-01-29 09:28  kymru  阅读(83)  评论(0编辑  收藏  举报