【题解】CF997C Sky Full of Stars
简记一下高维二项式反演的套路。
思路
高维二项式反演。
首先意识到 \(n \leq 10^6\) 且计数,并且求 “至少”,所以考虑用二项式反演处理。
这里如果用一维的二项式反演,可能要把行和列看成物品然后合并去重,比较麻烦,可以直接上二维二项式反演。
二维二项式反演结论:
\(F(N, M) = \sum\limits_{i = n}^N \sum\limits_{j = m}^{M} {N \choose i} {M \choose j} G(i, j) \Leftrightarrow G(N, M) = \sum\limits_{i = n}^N \sum\limits_{j = m}^{M} (-1)^{(N - i) + (M - j)} {N \choose i} {M \choose j} F(i, j)\).
具体证明考虑套两次二项式反演或者容斥系数。
所以考虑二项式反演的计数套路。
令 \(F(i, j)\) 为钦定有 \(i\) 行 \(j\) 列同色,其余任意染色的方案总数,\(G(i, j)\) 为刚好有 \(i\) 行 \(j\) 列同色的方案总数。
考虑求出 \(F\) 再通过二维二项式反演求出 \(G\).
\(F\) 因为同色具有传递性,行和列又可能相交,所以需要分类讨论:
-
\(i, j \neq 0\)
此时钦定的行和列都必定是同一种颜色。
\(F(i, j) = 3 {N \choose i} {N \choose j} 3^{(N - i)(N - j)}\).
-
\(i \neq 0, j = 0\) 或 \(i = 0, j \neq 0\)
此时两种情况是对称的,求一种就行。
当 \(i \neq 0, j = 0\) 时 \(F(i, j) = {N \choose i} 3^{i} 3^{N (N - i)}\).
-
\(i = j = 0\)
此时任意染色,总数为 \(3^{n^2}\).
考虑代入二项式反演的式子。
- 当贡献函数为分段函数时,考虑将 \(\sum\) 拆成对每段求和。
于是分讨一下:
- \(i, j \neq 0\) 的贡献
\(\sum\limits_{i = 1}^N \sum\limits_{j = 1}^N (-1)^{i + j} 3 {N \choose i} {N \choose j} 3^{(N - i)(N - j)}\).
把幂次拆开整理得 \(3^{N^2 + 1} \sum\limits_{i = 1}^N {N \choose i} (-1)^i 3^{-iN} \sum\limits_{j = 1}^N {N \choose j} (-1)^j 3^{-jN} 3^{ij}\).
这里凑二项式定理的步骤比较仙。
\(3^{N^2 + 1} \sum\limits_{i = 1}^N {N \choose i} (-1)^i 3^{-iN} \sum\limits_{j = 1}^N {N \choose j} (-1)^j 3^{j(i - N)}\).
因为 \(\sum\limits_{j = 1}^N {N \choose j} (-1)^j 3^{j(i - N)} = \sum\limits_{j = 1}^N {N \choose j} 1^{N - j} ((-3)^{i - N})^j = (1 - 3^{i - N})^N - 1\).
于是代入化简得贡献为 \(3^{N^2 + 1} \sum\limits_{i = 1}^N {N \choose i} (-1)^i 3^{-iN} ((1 - 3^{i - N})^N - 1)\).
- \(i \neq 0, j = 0\) 的贡献
很容易推出是 \(2 \cdot 3^{N^2} ((1 - 3^{1 - N})^N - 1)\)
- \(i = j = 0\) 的贡献
\(3^{N^2}\).
直接快速幂大力做。
代码
#include <cstdio>
using namespace std;
typedef long long ll;
const int maxn = 1e6 + 5;
const int mod = 998244353;
int n;
int fac[maxn], invf[maxn];
ll qpow(ll base, ll power)
{
ll res = 1;
base = (base + mod) % mod;
while (power)
{
if (power & 1) res = res * base % mod;
base = base * base % mod, power >>= 1;
}
return res;
}
void init(int lim)
{
fac[0] = invf[0] = fac[1] = invf[1] = 1;
for (int i = 2; i <= lim; i++) fac[i] = 1ll * fac[i - 1] * i % mod, invf[i] = 1ll * (mod - mod / i) * invf[mod % i] % mod;
for (int i = 2; i <= lim; i++) invf[i] = 1ll * invf[i - 1] * invf[i] % mod;
}
int C(int n, int m) { return 1ll * fac[n] * invf[m] % mod * invf[n - m] % mod; }
int main()
{
scanf("%d", &n);
init(n);
int ans = 0;
for (int i = 1; i <= n; i++)
{
int coe = C(n, i);
int pw1 = qpow(3, 1ll * (mod - 1 - i) * n % (mod - 1));
int pw2 = (qpow(1 - qpow(3, mod - 1 + i - n), n) - 1 + mod) % mod;
int val = 1ll * coe * pw1 % mod * pw2 % mod;
if (i & 1) ans = (ans - val + mod) % mod;
else ans = (ans + val) % mod;
}
ans = 1ll * ans * qpow(3, (1ll * n * n + 1) % (mod - 1)) % mod;
// printf("debug %lld\n", ans);
int coe = (qpow(1 - qpow(3, mod - n), n) - 1 + mod) % mod;
int sum = 2ll * qpow(3, 1ll * n * n % (mod - 1)) % mod % mod;
ans = (ans + 1ll * coe * sum % mod) % mod;
printf("%d\n", (mod - ans) % mod);
return 0;
}