CF1613F 题解

CF1613F 题解

题意:

给一棵以 \(1\) 为根的树,统计排列 \(P\) 的个数,要求满足以下条件:

\(\forall u > 1, P_u \neq P_{fa_u} - 1\),其中 \(fa_u\)\(u\) 在树上的父亲节点。

做法:

首先可以考虑容斥,即对于每个 \(i \in [0,n-1]\) 都计算出至少有 \(i\) 对点不满足条件的排列数量,

记这个数为 \(f_i\),则最终答案就等于 \(\sum _ {i = 0} ^ {n-1} (-1) ^ i f_i\),那么问题转化为如何计算 \(f\) 数组的值。

我们发现,一个点的所有儿子中,最多只有一个儿子与该点属于一对不满足条件的点,否则:

设有两对点 \((u,v),(u,w)\) 不满足条件,则有 \(P_v = P_u-1 = P_w\),不满足 \(P\) 是排列的前提。

故我们将每条连接一对不满足条件点的边标记,最终所有标记边形成若干条不交的链。

对于每条链,只要有一个在链上的点在 \(P\) 中的值确定了,整条链上点在 \(P\) 中的值就都确定了。

也就是说,我们可以将每条链都分别缩成一个点,在最后形成的新树中,假设还剩余 \(t\) 个点,

则任意一个长为 \(t\) 的排列都可以作为一个合法的方案,统计进 \(f_{n-t}\) 里去。

也就是说,我们只需要对于每个 \(i \in [0,n-1]\),计算出至少缩了 \(i\) 个点的方案数 \(g_i\)

再乘上 \((n-i)!\) 就等于 \(f_i\)

我们可以得到一个式子: \(g_i = [x^i]\prod _ {u = 1} ^ {n} (d_ux + 1)\),其中 \(d_u\) 代表点 \(u\) 的儿子数量。

式子成立的原因是,对于每个点 \(u\),可以选择其是否与一个儿子组成一对不满足条件的点,

两种情况的方案数分别是 \(d_u\)\(1\),而第一种情况会导致一个点被缩进链中,

故最后在多项式 \(\prod _ {u = 1} ^ {n} (d_ux + 1)\)\(x^i\) 项的系数,就是至少缩了 \(i\) 个点的方案数。

至于为什么可以将所有点的方案相乘,是因为任意两个点的方案独立,即互相不会造成影响。

那么问题就转化为计算多项式 \(\prod _ {u = 1} ^ {n} (d_ux + 1)\) 中每一项的系数。

一种时间复杂度为 \(O(n\log^2n)\) 的做法是分治,即:

我们记多项式 \(s(l,r) = \prod _ {u=l} ^ {r} (d_ux + 1)\),那么有 \(s(l,r) = s(l,mid) * s(mid+1,r)\)

其中 \(mid = \lfloor \frac {l + r} {2} \rfloor\)\(h = f*g\) 代表 \(h\)\(f\)\(g\) 的卷积。

故我们考虑计算 \(s(l,r)\) 的系数表示时,可以先计算 \(s(l,mid)\)\(s(mid+1,r)\) 的系数表示,

再做一次多项式乘法即可。

这样的时间复杂度是 \(T(n) = 2T(\frac {n} {2}) + O(n \log n)\),即 \(T(n) = O(n \log ^2 n)\)

但是我们还有一种更优秀的 \(O(n \log n)\) 的做法。

考虑如下恒等式:

\(\prod _ {u=1} ^ {n} (d_ux + 1) = \prod _ {k=1} ^ {n} (kx+1)^{\sum _ {u=1} ^ {n} [d_u = k]}\),即将所有儿子数量相等的点的方案合并。

我们记 \(cnt_k = \sum _ {u=1} ^ {n} [d_u = k]\),接下来我们可以尝试将 \((kx + 1) ^ {cnt_k}\) 展开,即:

\(h(k) = (kx + 1) ^ {cnt_k} = \sum _ {i = 0} ^ {cnt_k} \binom {cnt_k} {i} k ^ i x ^ i\)

也就是说,我们可以用总复杂度 \(O(n)\) 的时间,得到所有 \(k\) 对应 \(h(k)\) 的系数表示。

我们发现,如果我们倒序枚举 \(k\),并依次将 \(h(k)\) 卷在一起,这样的复杂度是可以接受的。

具体来说,我们记初始多项式 \(A = 1\),然后倒序枚举 \(k\)

对于每个 \(k\),我们都做一次 \(A = A * h(k)\) ,最后的 \(A\) 的系数表示就是我们要的结果。

而这样的复杂度是:

\(T(n)=O(\sum_{k = 1}^n \sum_{u=1}^{n}[d_u \geq k]\log n)\)

我们交换求和顺序,即:

\(\sum _ {k = 1} ^ {n} \sum _ {u = 1} ^ {n} [d_u \ge k] = \sum _ {u = 1} ^ n \sum _ {k = 1} ^ n [d_u \ge k] = \sum _ {u = 1} ^ {n} d_u = O(n)\)

故最终的时间复杂度 \(T(n) = O(n \log n)\)

下面给出 \(O(n \log ^ 2 n)\) 做法的代码:

#define LL long long
#define vi vector < LL >
#define rep(i, a, b) for (int i = (a); i <= (b); i++)
#define per(i, a, b) for (int i = (a); i >= (b); i--)
using namespace std;
const int N (1e6 + 10);
const LL G1 (3), G2 (332748118), mod (998244353);
LL ksm (LL a, int b, LL r = 1) {
	for (; b; b >>= 1, a = a * a % mod) if (b & 1) r = r * a % mod;
	return r;
}
LL ans, fac[N];
int n, m, d[N], rv[N];
void NTT (bool sgn, vi &f) {
	rep (i, 0, m - 1) if (i < rv[i]) swap (f[i], f[rv[i]]);
	for (int len = 2; len <= m; len <<= 1) {
		int l = len / 2; LL dt = ksm (sgn ? G2 : G1, (mod - 1) / len), w = 1;
		for (int st = 0; st < m; st += len, w = 1) rep (i, st, st + l - 1) {
			LL fl = f[i], fr = f[i + l] * w % mod;
			f[i] = (fl + fr) % mod, f[i + l] = (fl - fr + mod) % mod, (w *= dt) %= mod;
		}
	} if (sgn) {
		LL inv = ksm (m, mod - 2);
		rep (i, 0, m - 1) f[i] = f[i] * inv % mod;
	}
}
vi solve (int l, int r) {
	if (l == r) return {1ll,  1ll * d[l]};
	int mid = (l + r) >> 1;
	vi FL = solve (l, mid), FR = solve (mid + 1, r);
	for (m = 1; m <= r - l + 1; m <<= 1) ;
	rep (i, 0, m - 1) rv[i] = (rv[i >> 1] >> 1) | ((i & 1) ? (m >> 1) : 0);
	FL.resize (m), FR.resize (m), NTT (0, FL), NTT (0, FR);
	rep (i, 0, m - 1) FL[i] = FL[i] * FR[i] % mod;
	return NTT (1, FL), FL;
}
signed main() {
	fac[0] = fac[1] = 1;
	cin >> n; LL sgn = 1;
	rep (i, 2, n) {
		int u , v; cin >> u >> v;
		d[i]--, d[u]++, d[v]++, fac[i] = fac[i - 1] * 1ll * i % mod;
	} vi F = solve (1, n);
	rep (i, 0, n - 1) {
		sgn = (i & 1) ? (-1) : 1;
		ans = (ans + sgn * (fac[n - i] * F[i] % mod) + mod) % mod;
	} return cout << ans, 0;
}
posted @ 2021-12-03 16:19  GaryH  阅读(157)  评论(0编辑  收藏  举报