Loading

【题解】P5824 十二重计数法 / 球盒问题

球盒问题全家桶。

禁忌「十二重存在」

给定 \(n\) 个球和 \(m\) 个盒,求在 12 种不同的限制条件下把球放完的方案数。

思路

排列组合 + 斯特林数。

前置知识:【题解】P4389 付公主的背包 / Euler 变换

其壹

球可区分,盒可区分,数量无限制。

盒相当于颜色,放球相当于给小球染色,共有 \(m^n\) 种染色方案。

球可区分,盒可区分,每盒至多一个。

\(n < m\) 时无解。

反之可以按顺序考虑每个球放进的盒子,方案数是 \(m^{\underline{n}}\).

球可区分,盒可区分,每盒至少一个。

考虑先分进 \(m\) 个盒子再钦定顺序,答案是 \(m! {n \brace m}\).

其贰

球可区分,盒不可区分,数量无限制。

枚举非空盒子的数量,答案是 \(\sum\limits_{i = 0}^m {n \brace i}\).

球可区分,盒不可区分,每盒至多一个。

此时所有合法方案都是等价的,答案是 \([n \leq m]\).

球可区分,盒不可区分,每盒至少一个。

根据第二类斯特林数的定义为 \({n \brace m}\).

其叁

球不可区分,盒可区分,数量无限制。

可以转化成:对于一个 \((n + 1) \times m\) 的网格,向下走一步等价于放一个球,向右走一步等价于切换到下一个盒子,最终到达 \((n + 1, m)\) 的方案总数。

答案是 \({n + m - 1 \choose n}\).

球不可区分,盒可区分,每盒至多一个。

等价于选出 \(n\) 个非空的盒子,答案是 \({m \choose n}\).

球不可区分,盒可区分,每盒至少一个。

插板法:\({n - 1 \choose m - 1}\).

其肆

球不可区分,盒不可区分,数量无限制。

整数拆分:\(p(n, m)\),将 \(n\) 拆成 \(m\) 个无序自然数的方案数。

【题解】P4389 付公主的背包 / Euler 变换

拾壹

球不可区分,盒不可区分,每盒至多一个。

同伍得所有方案都是等价的,答案是 \([n \leq m]\).

拾贰

球不可区分,盒不可区分,每盒至少一个。

先给每个盒子放一个,此后同拾,为 \(p(n - m, m)\)

总复杂度 \(O(n \log n)\)

代码

#include <cstdio>
#include <cstring>
#include <algorithm>
using namespace std;

typedef long long ll;

const int sz = 2e6 + 5;
const int mod = 998244353;
const int g = 3;

int n, m, ans[15];
int cnt[sz], rev[sz], inv[sz], fac[sz], invf[sz];
ll F[sz], G[sz], wp[sz];
ll Ft[sz], Rt[sz], ft[sz], rt[sz], Fn[sz], Rn[sz];

void calc_rev(int k) { for (int i = 1; i < k; i++) rev[i] = (rev[i >> 1] >> 1 | (i & 1 ? k >> 1 : 0)); }

ll qpow(ll base, ll power = mod - 2, ll mod = mod)
{
    ll res = 1;
    while (power)
    {
        if (power & 1) res = res * base % mod;
        base = base * base % mod;
        power >>= 1;
    }
    return res;
}

void NTT(ll *A, int n)
{
    calc_rev(n);
    for (int i = 1; i < n; i++)
        if (rev[i] > i) swap(A[i], A[rev[i]]);
    for (int len = 2, m = 1; len <= n; m = len, len <<= 1)
    {
        ll wn = qpow(g, (mod - 1) / len, mod);
        wp[0] = 1;
        for (int i = 1; i <= len; i++) wp[i] = wp[i - 1] * wn % mod;
        for (int l = 0, r = len - 1; r <= n; l += len, r += len)
        {
            int w = 0;
            for (int p = l; p < l + m; p++, w++)
            {
                ll x = A[p], y = wp[w] * A[p + m] % mod;
                A[p] = (x + y) % mod, A[p + m] = (x - y + mod) % mod;
            }
        }
    }
}

void INTT(ll *A, int n)
{
    NTT(A, n);
    reverse(A + 1, A + n);
    int inv = qpow(n, mod - 2, mod);
    for (int i = 0; i < n; i++) A[i] = 1ll * A[i] * inv % mod;
}

void times(ll *f, ll *g, int len1, int len2, int lim)
{
    int m = len1 + len2 - 1, k = 1;
    while (k <= m) k <<= 1;
    for (int i = 0; i < len2; i++) Rt[i] = g[i];
    NTT(f, k), NTT(Rt, k);
    for (int i = 0; i < k; i++) f[i] = f[i] * Rt[i] % mod;
    INTT(f, k);
    for (int i = lim; i < n; i++) f[i] = 0;
    for (int i = 0; i < k; i++) Rt[i] = 0;
}

void invp(ll *f, ll *r, int n)
{
    int k = 1;
    while (k < n) k <<= 1;
    r[0] = qpow(f[0], mod - 2, mod);
    for (int len = 2, m = 1; len <= k; m = len, len <<= 1)
    {
        for (int i = 0; i < len; i++) Rt[i] = r[i], Ft[i] = f[i];
        NTT(Ft, len), NTT(Rt, len);
        for (int i = 0; i < len; i++) Rt[i] = Rt[i] * Ft[i] % mod;
        INTT(Rt, len);
        for (int i = 0; i < m; i++) Rt[i] = 0; Rt[0] = 1;
        for (int i = 0; i < len; i++) Ft[i] = r[i];
        NTT(Ft, len), NTT(Rt, len);
        for (int i = 0; i < len; i++) Rt[i] = Rt[i] * Ft[i] % mod;
        INTT(Rt, len);
        for (int i = m; i < len; i++) r[i] = (r[i] * 2ll - Rt[i] + mod) % mod;
    }
    memset(Ft, 0, k * sizeof(ll));
    memset(Rt, 0, k * sizeof(ll));
    for (int i = n; i < k; i++) r[i] = 0;
}

void diffp(ll *f, ll *der, int n)
{
    for (int i = 1; i < n; i++) der[i - 1] = f[i] * i % mod;
    der[n - 1] = 0;
}

void intep(ll *f, ll *inte, int n)
{
    for (int i = 1; i < n; i++)
        if (!inv[i]) inv[i] = (i == 1 ? 1 : inv[mod % i] * (mod - mod / i) % mod);
    for (int i = 1; i < n; i++) inte[i] = f[i - 1] * inv[i] % mod;
    inte[0] = 0;
}

void lnp(ll *f, ll *ln, int n)
{
    diffp(f, ft, n), invp(f, rt, n);
    int k = 1;
    while (k < (n << 1)) k <<= 1;
    NTT(ft, k), NTT(rt, k);
    for (int i = 0; i < k; i++) ft[i] = ft[i] * rt[i] % mod;
    INTT(ft, k);
    intep(ft, ln, n);
    for (int i = 0; i < k; i++) ft[i] = rt[i] = 0;
}

void expp(ll *f, ll *exp, int n)
{
    int k = 1;
    while (k < n) k <<= 1;
    exp[0] = 1;
    for (int len = 2, m = 1; len <= k; m = len, len <<= 1)
    {
        for (int i = 0; i < len; i++) Fn[i] = exp[i];
        lnp(Fn, Rn, len);
        for (int i = 0; i < len; i++) Rn[i] = (f[i] - Rn[i] + mod) % mod;
        Rn[0] = (Rn[0] + 1) % mod;
        NTT(Fn, len << 1), NTT(Rn, len << 1);
        for (int i = 0; i < (len << 1); i++) Fn[i] = Fn[i] * Rn[i] % mod;
        INTT(Fn, len << 1);
        for (int i = 0; i < len; i++) exp[i] = Fn[i];
    }
    for (int i = 0; i < (k << 1); i++) Fn[i] = Rn[i] = 0;
    for (int i = n; i < k; i++) exp[i] = 0;
}

void init(int lim)
{
    fac[0] = inv[0] = invf[0] = fac[1] = inv[1] = invf[1] = 1;
    for (int i = 2; i <= lim; i++) fac[i] = 1ll * fac[i - 1] * i % mod, inv[i] = 1ll * (mod - mod / i) * inv[mod % i] % mod;
    for (int i = 2; i <= lim; i++) invf[i] = 1ll * invf[i - 1] * inv[i] % mod;
}

int C(int n, int m) { return (n < m ? 0 : 1ll * fac[n] * invf[m] % mod * invf[n - m] % mod); }

void solve1()
{
    ans[1] = qpow(m, n);
    if (n <= m)
    {
        ans[5] = ans[11] = 1;
        ans[2] = 1ll * fac[m] * invf[m - n] % mod;
        ans[8] = C(m, n);
    }
    ans[7] = invf[n];
    // printf("done %d\n", ans[7]);
    if (n + m - 1 >= m) ans[7] = 1ll * ans[7] * fac[n + m - 1] % mod * invf[m - 1] % mod;
    if (n >= m) ans[9] = C(n - 1, m - 1);
}

void solve2()
{
    for (int i = 0; i <= min(n, m); i++)
    {
        F[i] = qpow(i, n) * invf[i] % mod;
        G[i] = (i & 1 ? mod - invf[i] : invf[i]);
    }
    times(F, G, min(n, m) + 1, min(n, m) + 1, min(n, m) + 1);
    if (n >= m) ans[3] = 1ll * F[m] * fac[m] % mod;
    for (int i = 1; i <= min(n, m); i++) ans[4] = (ans[4] + F[i]) % mod;
    ans[6] = (n < m ? 0 : F[m]);
    memset(F, 0, sizeof(F)), memset(G, 0, sizeof(G));
}

void solve3()
{
    for (int i = 1; i <= m; i++)
        for (int j = 0; j <= n; j += i)
            F[j] = (F[j] + 1ll * invf[j / i] * fac[j / i - 1] % mod) % mod;
    expp(F, G, n + 1);
    ans[10] = G[n];
    if (n >= m) ans[12] = G[n - m];
}

int main()
{
    scanf("%d%d", &n, &m);
    init(max(n, m) << 1), solve1(), solve2(), solve3();
    // printf("debug %d\n", invf[n]);
    for (int i = 1; i <= 12; i++) printf("%d\n", ans[i]);
    return 0;
}
posted @ 2023-02-20 17:03  kymru  阅读(59)  评论(0编辑  收藏  举报