【题解】P5824 十二重计数法 / 球盒问题
球盒问题全家桶。
禁忌「十二重存在」
给定 \(n\) 个球和 \(m\) 个盒,求在 12 种不同的限制条件下把球放完的方案数。
思路
排列组合 + 斯特林数。
前置知识:【题解】P4389 付公主的背包 / Euler 变换
其壹
壹
球可区分,盒可区分,数量无限制。
盒相当于颜色,放球相当于给小球染色,共有 \(m^n\) 种染色方案。
贰
球可区分,盒可区分,每盒至多一个。
\(n < m\) 时无解。
反之可以按顺序考虑每个球放进的盒子,方案数是 \(m^{\underline{n}}\).
叁
球可区分,盒可区分,每盒至少一个。
考虑先分进 \(m\) 个盒子再钦定顺序,答案是 \(m! {n \brace m}\).
其贰
肆
球可区分,盒不可区分,数量无限制。
枚举非空盒子的数量,答案是 \(\sum\limits_{i = 0}^m {n \brace i}\).
伍
球可区分,盒不可区分,每盒至多一个。
此时所有合法方案都是等价的,答案是 \([n \leq m]\).
陆
球可区分,盒不可区分,每盒至少一个。
根据第二类斯特林数的定义为 \({n \brace m}\).
其叁
柒
球不可区分,盒可区分,数量无限制。
可以转化成:对于一个 \((n + 1) \times m\) 的网格,向下走一步等价于放一个球,向右走一步等价于切换到下一个盒子,最终到达 \((n + 1, m)\) 的方案总数。
答案是 \({n + m - 1 \choose n}\).
捌
球不可区分,盒可区分,每盒至多一个。
等价于选出 \(n\) 个非空的盒子,答案是 \({m \choose n}\).
玖
球不可区分,盒可区分,每盒至少一个。
插板法:\({n - 1 \choose m - 1}\).
其肆
拾
球不可区分,盒不可区分,数量无限制。
整数拆分:\(p(n, m)\),将 \(n\) 拆成 \(m\) 个无序自然数的方案数。
拾壹
球不可区分,盒不可区分,每盒至多一个。
同伍得所有方案都是等价的,答案是 \([n \leq m]\).
拾贰
球不可区分,盒不可区分,每盒至少一个。
先给每个盒子放一个,此后同拾,为 \(p(n - m, m)\)。
总复杂度 \(O(n \log n)\)
代码
#include <cstdio>
#include <cstring>
#include <algorithm>
using namespace std;
typedef long long ll;
const int sz = 2e6 + 5;
const int mod = 998244353;
const int g = 3;
int n, m, ans[15];
int cnt[sz], rev[sz], inv[sz], fac[sz], invf[sz];
ll F[sz], G[sz], wp[sz];
ll Ft[sz], Rt[sz], ft[sz], rt[sz], Fn[sz], Rn[sz];
void calc_rev(int k) { for (int i = 1; i < k; i++) rev[i] = (rev[i >> 1] >> 1 | (i & 1 ? k >> 1 : 0)); }
ll qpow(ll base, ll power = mod - 2, ll mod = mod)
{
ll res = 1;
while (power)
{
if (power & 1) res = res * base % mod;
base = base * base % mod;
power >>= 1;
}
return res;
}
void NTT(ll *A, int n)
{
calc_rev(n);
for (int i = 1; i < n; i++)
if (rev[i] > i) swap(A[i], A[rev[i]]);
for (int len = 2, m = 1; len <= n; m = len, len <<= 1)
{
ll wn = qpow(g, (mod - 1) / len, mod);
wp[0] = 1;
for (int i = 1; i <= len; i++) wp[i] = wp[i - 1] * wn % mod;
for (int l = 0, r = len - 1; r <= n; l += len, r += len)
{
int w = 0;
for (int p = l; p < l + m; p++, w++)
{
ll x = A[p], y = wp[w] * A[p + m] % mod;
A[p] = (x + y) % mod, A[p + m] = (x - y + mod) % mod;
}
}
}
}
void INTT(ll *A, int n)
{
NTT(A, n);
reverse(A + 1, A + n);
int inv = qpow(n, mod - 2, mod);
for (int i = 0; i < n; i++) A[i] = 1ll * A[i] * inv % mod;
}
void times(ll *f, ll *g, int len1, int len2, int lim)
{
int m = len1 + len2 - 1, k = 1;
while (k <= m) k <<= 1;
for (int i = 0; i < len2; i++) Rt[i] = g[i];
NTT(f, k), NTT(Rt, k);
for (int i = 0; i < k; i++) f[i] = f[i] * Rt[i] % mod;
INTT(f, k);
for (int i = lim; i < n; i++) f[i] = 0;
for (int i = 0; i < k; i++) Rt[i] = 0;
}
void invp(ll *f, ll *r, int n)
{
int k = 1;
while (k < n) k <<= 1;
r[0] = qpow(f[0], mod - 2, mod);
for (int len = 2, m = 1; len <= k; m = len, len <<= 1)
{
for (int i = 0; i < len; i++) Rt[i] = r[i], Ft[i] = f[i];
NTT(Ft, len), NTT(Rt, len);
for (int i = 0; i < len; i++) Rt[i] = Rt[i] * Ft[i] % mod;
INTT(Rt, len);
for (int i = 0; i < m; i++) Rt[i] = 0; Rt[0] = 1;
for (int i = 0; i < len; i++) Ft[i] = r[i];
NTT(Ft, len), NTT(Rt, len);
for (int i = 0; i < len; i++) Rt[i] = Rt[i] * Ft[i] % mod;
INTT(Rt, len);
for (int i = m; i < len; i++) r[i] = (r[i] * 2ll - Rt[i] + mod) % mod;
}
memset(Ft, 0, k * sizeof(ll));
memset(Rt, 0, k * sizeof(ll));
for (int i = n; i < k; i++) r[i] = 0;
}
void diffp(ll *f, ll *der, int n)
{
for (int i = 1; i < n; i++) der[i - 1] = f[i] * i % mod;
der[n - 1] = 0;
}
void intep(ll *f, ll *inte, int n)
{
for (int i = 1; i < n; i++)
if (!inv[i]) inv[i] = (i == 1 ? 1 : inv[mod % i] * (mod - mod / i) % mod);
for (int i = 1; i < n; i++) inte[i] = f[i - 1] * inv[i] % mod;
inte[0] = 0;
}
void lnp(ll *f, ll *ln, int n)
{
diffp(f, ft, n), invp(f, rt, n);
int k = 1;
while (k < (n << 1)) k <<= 1;
NTT(ft, k), NTT(rt, k);
for (int i = 0; i < k; i++) ft[i] = ft[i] * rt[i] % mod;
INTT(ft, k);
intep(ft, ln, n);
for (int i = 0; i < k; i++) ft[i] = rt[i] = 0;
}
void expp(ll *f, ll *exp, int n)
{
int k = 1;
while (k < n) k <<= 1;
exp[0] = 1;
for (int len = 2, m = 1; len <= k; m = len, len <<= 1)
{
for (int i = 0; i < len; i++) Fn[i] = exp[i];
lnp(Fn, Rn, len);
for (int i = 0; i < len; i++) Rn[i] = (f[i] - Rn[i] + mod) % mod;
Rn[0] = (Rn[0] + 1) % mod;
NTT(Fn, len << 1), NTT(Rn, len << 1);
for (int i = 0; i < (len << 1); i++) Fn[i] = Fn[i] * Rn[i] % mod;
INTT(Fn, len << 1);
for (int i = 0; i < len; i++) exp[i] = Fn[i];
}
for (int i = 0; i < (k << 1); i++) Fn[i] = Rn[i] = 0;
for (int i = n; i < k; i++) exp[i] = 0;
}
void init(int lim)
{
fac[0] = inv[0] = invf[0] = fac[1] = inv[1] = invf[1] = 1;
for (int i = 2; i <= lim; i++) fac[i] = 1ll * fac[i - 1] * i % mod, inv[i] = 1ll * (mod - mod / i) * inv[mod % i] % mod;
for (int i = 2; i <= lim; i++) invf[i] = 1ll * invf[i - 1] * inv[i] % mod;
}
int C(int n, int m) { return (n < m ? 0 : 1ll * fac[n] * invf[m] % mod * invf[n - m] % mod); }
void solve1()
{
ans[1] = qpow(m, n);
if (n <= m)
{
ans[5] = ans[11] = 1;
ans[2] = 1ll * fac[m] * invf[m - n] % mod;
ans[8] = C(m, n);
}
ans[7] = invf[n];
// printf("done %d\n", ans[7]);
if (n + m - 1 >= m) ans[7] = 1ll * ans[7] * fac[n + m - 1] % mod * invf[m - 1] % mod;
if (n >= m) ans[9] = C(n - 1, m - 1);
}
void solve2()
{
for (int i = 0; i <= min(n, m); i++)
{
F[i] = qpow(i, n) * invf[i] % mod;
G[i] = (i & 1 ? mod - invf[i] : invf[i]);
}
times(F, G, min(n, m) + 1, min(n, m) + 1, min(n, m) + 1);
if (n >= m) ans[3] = 1ll * F[m] * fac[m] % mod;
for (int i = 1; i <= min(n, m); i++) ans[4] = (ans[4] + F[i]) % mod;
ans[6] = (n < m ? 0 : F[m]);
memset(F, 0, sizeof(F)), memset(G, 0, sizeof(G));
}
void solve3()
{
for (int i = 1; i <= m; i++)
for (int j = 0; j <= n; j += i)
F[j] = (F[j] + 1ll * invf[j / i] * fac[j / i - 1] % mod) % mod;
expp(F, G, n + 1);
ans[10] = G[n];
if (n >= m) ans[12] = G[n - m];
}
int main()
{
scanf("%d%d", &n, &m);
init(max(n, m) << 1), solve1(), solve2(), solve3();
// printf("debug %d\n", invf[n]);
for (int i = 1; i <= 12; i++) printf("%d\n", ans[i]);
return 0;
}