CF1613F 题解
CF1613F 题解
题意:
给一棵以 \(1\) 为根的树,统计排列 \(P\) 的个数,要求满足以下条件:
\(\forall u > 1, P_u \neq P_{fa_u} - 1\),其中 \(fa_u\) 为 \(u\) 在树上的父亲节点。
做法:
首先可以考虑容斥,即对于每个 \(i \in [0,n-1]\) 都计算出至少有 \(i\) 对点不满足条件的排列数量,
记这个数为 \(f_i\),则最终答案就等于 \(\sum _ {i = 0} ^ {n-1} (-1) ^ i f_i\),那么问题转化为如何计算 \(f\) 数组的值。
我们发现,一个点的所有儿子中,最多只有一个儿子与该点属于一对不满足条件的点,否则:
设有两对点 \((u,v),(u,w)\) 不满足条件,则有 \(P_v = P_u-1 = P_w\),不满足 \(P\) 是排列的前提。
故我们将每条连接一对不满足条件点的边标记,最终所有标记边形成若干条不交的链。
对于每条链,只要有一个在链上的点在 \(P\) 中的值确定了,整条链上点在 \(P\) 中的值就都确定了。
也就是说,我们可以将每条链都分别缩成一个点,在最后形成的新树中,假设还剩余 \(t\) 个点,
则任意一个长为 \(t\) 的排列都可以作为一个合法的方案,统计进 \(f_{n-t}\) 里去。
也就是说,我们只需要对于每个 \(i \in [0,n-1]\),计算出至少缩了 \(i\) 个点的方案数 \(g_i\),
再乘上 \((n-i)!\) 就等于 \(f_i\)。
我们可以得到一个式子: \(g_i = [x^i]\prod _ {u = 1} ^ {n} (d_ux + 1)\),其中 \(d_u\) 代表点 \(u\) 的儿子数量。
式子成立的原因是,对于每个点 \(u\),可以选择其是否与一个儿子组成一对不满足条件的点,
两种情况的方案数分别是 \(d_u\) 和 \(1\),而第一种情况会导致一个点被缩进链中,
故最后在多项式 \(\prod _ {u = 1} ^ {n} (d_ux + 1)\) 中 \(x^i\) 项的系数,就是至少缩了 \(i\) 个点的方案数。
至于为什么可以将所有点的方案相乘,是因为任意两个点的方案独立,即互相不会造成影响。
那么问题就转化为计算多项式 \(\prod _ {u = 1} ^ {n} (d_ux + 1)\) 中每一项的系数。
一种时间复杂度为 \(O(n\log^2n)\) 的做法是分治,即:
我们记多项式 \(s(l,r) = \prod _ {u=l} ^ {r} (d_ux + 1)\),那么有 \(s(l,r) = s(l,mid) * s(mid+1,r)\),
其中 \(mid = \lfloor \frac {l + r} {2} \rfloor\),\(h = f*g\) 代表 \(h\) 是 \(f\) 和 \(g\) 的卷积。
故我们考虑计算 \(s(l,r)\) 的系数表示时,可以先计算 \(s(l,mid)\) 和 \(s(mid+1,r)\) 的系数表示,
再做一次多项式乘法即可。
这样的时间复杂度是 \(T(n) = 2T(\frac {n} {2}) + O(n \log n)\),即 \(T(n) = O(n \log ^2 n)\)。
但是我们还有一种更优秀的 \(O(n \log n)\) 的做法。
考虑如下恒等式:
\(\prod _ {u=1} ^ {n} (d_ux + 1) = \prod _ {k=1} ^ {n} (kx+1)^{\sum _ {u=1} ^ {n} [d_u = k]}\),即将所有儿子数量相等的点的方案合并。
我们记 \(cnt_k = \sum _ {u=1} ^ {n} [d_u = k]\),接下来我们可以尝试将 \((kx + 1) ^ {cnt_k}\) 展开,即:
\(h(k) = (kx + 1) ^ {cnt_k} = \sum _ {i = 0} ^ {cnt_k} \binom {cnt_k} {i} k ^ i x ^ i\)。
也就是说,我们可以用总复杂度 \(O(n)\) 的时间,得到所有 \(k\) 对应 \(h(k)\) 的系数表示。
我们发现,如果我们倒序枚举 \(k\),并依次将 \(h(k)\) 卷在一起,这样的复杂度是可以接受的。
具体来说,我们记初始多项式 \(A = 1\),然后倒序枚举 \(k\),
对于每个 \(k\),我们都做一次 \(A = A * h(k)\) ,最后的 \(A\) 的系数表示就是我们要的结果。
而这样的复杂度是:
\(T(n)=O(\sum_{k = 1}^n \sum_{u=1}^{n}[d_u \geq k]\log n)\)。
我们交换求和顺序,即:
\(\sum _ {k = 1} ^ {n} \sum _ {u = 1} ^ {n} [d_u \ge k] = \sum _ {u = 1} ^ n \sum _ {k = 1} ^ n [d_u \ge k] = \sum _ {u = 1} ^ {n} d_u = O(n)\)。
故最终的时间复杂度 \(T(n) = O(n \log n)\)。
下面给出 \(O(n \log ^ 2 n)\) 做法的代码:
#define LL long long
#define vi vector < LL >
#define rep(i, a, b) for (int i = (a); i <= (b); i++)
#define per(i, a, b) for (int i = (a); i >= (b); i--)
using namespace std;
const int N (1e6 + 10);
const LL G1 (3), G2 (332748118), mod (998244353);
LL ksm (LL a, int b, LL r = 1) {
for (; b; b >>= 1, a = a * a % mod) if (b & 1) r = r * a % mod;
return r;
}
LL ans, fac[N];
int n, m, d[N], rv[N];
void NTT (bool sgn, vi &f) {
rep (i, 0, m - 1) if (i < rv[i]) swap (f[i], f[rv[i]]);
for (int len = 2; len <= m; len <<= 1) {
int l = len / 2; LL dt = ksm (sgn ? G2 : G1, (mod - 1) / len), w = 1;
for (int st = 0; st < m; st += len, w = 1) rep (i, st, st + l - 1) {
LL fl = f[i], fr = f[i + l] * w % mod;
f[i] = (fl + fr) % mod, f[i + l] = (fl - fr + mod) % mod, (w *= dt) %= mod;
}
} if (sgn) {
LL inv = ksm (m, mod - 2);
rep (i, 0, m - 1) f[i] = f[i] * inv % mod;
}
}
vi solve (int l, int r) {
if (l == r) return {1ll, 1ll * d[l]};
int mid = (l + r) >> 1;
vi FL = solve (l, mid), FR = solve (mid + 1, r);
for (m = 1; m <= r - l + 1; m <<= 1) ;
rep (i, 0, m - 1) rv[i] = (rv[i >> 1] >> 1) | ((i & 1) ? (m >> 1) : 0);
FL.resize (m), FR.resize (m), NTT (0, FL), NTT (0, FR);
rep (i, 0, m - 1) FL[i] = FL[i] * FR[i] % mod;
return NTT (1, FL), FL;
}
signed main() {
fac[0] = fac[1] = 1;
cin >> n; LL sgn = 1;
rep (i, 2, n) {
int u , v; cin >> u >> v;
d[i]--, d[u]++, d[v]++, fac[i] = fac[i - 1] * 1ll * i % mod;
} vi F = solve (1, n);
rep (i, 0, n - 1) {
sgn = (i & 1) ? (-1) : 1;
ans = (ans + sgn * (fac[n - i] * F[i] % mod) + mod) % mod;
} return cout << ans, 0;
}