Codeforces Gym102538(300iq contest 3)A. Airplane Cliques
给定一棵 \(n\) 个节点的树。定义树上两点距离为它们之间边的数量。
称一对节点是友好的,当且仅当两点之间距离小于等于 \(x\)。
称一个 \(k\) 个节点的集合是友好集合,当且仅当集合中任意两个节点都是友好的。
请对所有 \(k=1\ldots n\),求出恰有 \(k\) 个节点的友好集合数量。
\(1\le n\le 3\cdot 10^5,0\le x\le n-1\)。
首先考虑如何快速判定一个集合是友好集合。两两枚举显然不可取,有一个结论是,两两枚举取到的最大值中的其中一个节点是点集中深度最大的点(以树上任意一个节点为根)。因此考虑按照 \(\text{bfs}\) 序加入每个点,计算当前这个点及之前的点组成的友好集合数。
那么对于每个按照 \(\text{bfs}\) 序排好的节点 \(p_i\),要求出前面的与 \(p_i\) 距离 \(\le x\) 的节点 \(p_j(j<i)\) 的数量。
设 \(cnt_i\) 表示这样的 \(p_j\) 的数量,答案为 \(\sum_{i=1}^{n}\binom{cnt_i}{k-1}\),这一部分令 \(c_i=\sum_{j=1}^{n} [cnt_j=i]\),可以化作差卷积的形式,用 \(\text{NTT}\) 可以 \(O(n\log n)\) 解决。
那么关键问题就是快速求出所有 \(cnt_i\)。这个跟 \(\text{Luogu P6329}\) 几乎一样。
运用到点分树的重要性质:
- 点分树上的两个节点 \(p,q\) 的 \(\text{lca}\) 一定在原树 \(p,q\) 的路径上。
- 点分树的树高为 \(O(\log n)\) 级别。
此题中运用这两个性质,暴力跳 \(p\) 在点分树上的祖先,只需计算 \(\operatorname{dis}(p,\text{lca})+\operatorname{dis}(q,\text{lca})\le x\) 的 \(q\) 的数量。注意到这样计算会重复计数,此时在点分树上减去包含 \(p\) 的子树方向的 \(q\) 的数量即可。统计时使用树状数组,复杂度 \(O(n\log^2 n)\)。
注意对每个点开树状数组时不能开到 \(n\),否则空间复杂度会变为 \(O(n^2)\),开到子树内最大深度即可,根据点分树的性质,空间复杂度为 \(O(n\log n)\)。
#include <bits/stdc++.h>
#define eb emplace_back
using namespace std;
typedef long long ll;
typedef vector<int> Poly;
const int N = 2097152, mod = 998244353;
int n, cnt, dis, root, totsz; ll fac[N], inv[N], Inv[N];
int a[N], f[N], sz[N], fa[N], dep[N], mxd[N], t[N], r[N];
int seq[N], Dep[N], anc[N][20];
bool vis[N]; vector<int> G[N];
struct Fenwick_Tree {
vector<int> c;
int sz;
inline void add(int x, int y) {
++ x; assert(x >= 1 && x <= sz);
for (; x <= sz; x += x & -x) c[x] += y;
}
inline int ask(int x) {
if (x < 0) return 0;
++ x, x = min(x, sz);
assert(x >= 1 && x <= sz);
int ans = 0;
for (; x; x ^= x & -x) ans += c[x];
return ans;
}
inline void init(int _sz) {
c.resize(_sz + 5), sz = _sz;
}
Fenwick_Tree(){}
}btr[N], btr_fa[N];
inline int add(int x, int y) { return x + y >= mod ? x + y - mod : x + y; }
inline int dec(int x, int y) { return x - y < 0 ? x - y + mod : x - y; }
inline int mul(int x, int y) { return (ll)x * y % mod; }
inline int qpow(int a, int b) {
int ans = 1;
for (; b; b >>= 1, a = mul(a, a)) if (b & 1) ans = mul(ans, a);
return ans;
}
inline void NTT(Poly &A, int len, int type){
for (int i = 0; i < len; ++ i) if(i < r[i]) swap(A[i], A[r[i]]);
for (int mid = 1; mid < len; mid <<= 1) {
int Wn = qpow(type == 1 ? 3 : (mod + 1) / 3, (mod - 1) / (mid << 1));
for (int j = 0; j < len; j += (mid << 1))
for (int k = 0, w = 1; k < mid; ++ k, w = mul(w, Wn)) {
int x = A[j + k], y = mul(w, A[j + k + mid]);
A[j + k] = add(x, y), A[j + k + mid] = dec(x, y);
}
}
if (type > 0) return;
for (int i = 0; i < len; ++ i) A[i] = mul(A[i], Inv[len]);
}
inline void init_rev(int len) {
for (int i = 0; i < len; ++ i) r[i] = r[i >> 1] >> 1 | ((i & 1) * (len >> 1));
}
inline void get_root(int x, int fa) {
f[x] = 0, sz[x] = 1;
for (auto y : G[x]) {
if (vis[y] || y == fa) continue;
get_root(y, x);
f[x] = max(f[x], sz[y]);
sz[x] += sz[y];
}
f[x] = max(f[x], totsz - sz[x]);
if (!root || f[x] < f[root]) root = x;
}
inline void dfs(int x, int fa) {
sz[x] = 1, dep[x] = dep[fa] + 1;
mxd[x] = dep[x];
for (auto y : G[x]) {
if (vis[y] || y == fa) continue;
dfs(y, x);
sz[x] += sz[y];
mxd[x] = max(mxd[x], mxd[y]);
}
}
inline void build_dividetree(int x) {
vis[x] = 1;
dfs(x, 0);
btr[x].init(mxd[x]);
for (auto y : G[x]) {
if (vis[y]) continue;
root = 0, totsz = sz[y];
get_root(y, 0);
fa[root] = x;
btr_fa[root].init(mxd[y] + 5);
build_dividetree(root);
}
}
queue<int> q;
inline void bfs() {
q.push(1);
memset(vis, 0, sizeof(bool) * (n + 1));
while (!q.empty()) {
int x = q.front(); q.pop();
vis[x] = 1, seq[++cnt] = x;
for (auto y : G[x]) {
if (vis[y]) continue;
q.push(y);
}
}
}
inline void precalc(int n) {
fac[0] = inv[0] = Inv[0] = fac[1] = inv[1] = Inv[1] = 1;
for (int i = 2; i <= n; ++ i)
fac[i] = fac[i - 1] * i % mod,
Inv[i] = (mod - mod / i) * Inv[mod % i] % mod,
inv[i] = inv[i - 1] * Inv[i] % mod;
}
inline void Dfs(int x, int fa) {
Dep[x] = Dep[fa] + 1, anc[x][0] = fa;
for (auto y : G[x]) {
if (y == fa) continue;
Dfs(y, x);
}
}
inline int lca(int x, int y) {
if (Dep[x] < Dep[y]) swap(x, y);
for (int i = 0; i <= 19; ++ i)
if ((Dep[x] - Dep[y]) >> i & 1)
x = anc[x][i];
if (x == y) return x;
for (int i = 19; ~i; -- i)
if (anc[x][i] ^ anc[y][i])
x = anc[x][i], y = anc[y][i];
return anc[x][0];
}
inline int dist(int x, int y) {
int u = lca(x, y);
return Dep[x] + Dep[y] - Dep[u] * 2;
}
int main() {
precalc(N - 1);
scanf("%d%d", &n, &dis);
for (int i = 1, u, v; i < n; ++ i) {
scanf("%d%d", &u, &v);
G[u].eb(v), G[v].eb(u);
}
root = 0, totsz = n, get_root(1, 0);
build_dividetree(root);
bfs(); Dfs(1, 0);
for (int j = 1; j <= 19; ++ j)
for (int i = 1; i <= n; ++ i)
anc[i][j] = anc[anc[i][j - 1]][j - 1];
for (int i = 1; i <= n; ++ i) {
int x = seq[i];
a[x] = btr[x].ask(dis);
for (int now = fa[x], son = x; now; son = now, now = fa[now]) {
int D = dist(now, x);
a[x] += btr[now].ask(dis - D) - btr_fa[son].ask(dis - D);
}
for (int now = x, son = 0; now; son = now, now = fa[now]) {
int D = dist(now, x);
btr[now].add(D, 1);
if (now != x) btr_fa[son].add(D, 1);
}
}
Poly F, P; F.resize(n), P.resize(n);
for (int i = 1; i <= n; ++ i) ++ t[a[i]];
for (int i = 0; i < n; ++ i) F[i] = mul(t[i], fac[i]), P[i] = inv[i];
reverse(F.begin(), F.end());
int lim = 1; while (lim < n + n - 1) lim <<= 1; init_rev(lim);
F.resize(lim), P.resize(lim);
NTT(F, lim, 1), NTT(P, lim, 1);
for (int i = 0; i < lim; ++ i) F[i] = mul(F[i], P[i]);
NTT(F, lim, -1);
for (int i = 1; i <= n; ++ i) printf("%d ", mul(F[n - i], inv[i - 1]));
return 0;
}