CF1039D You Are Given a Tree (树形 dp + 贪心 + 根号分治)
树形 dp + 贪心 + 根号分治
题目是一个经典问题,可以用树形 dp 和贪心解决。设 \(f_u\) 表示以 \(u\) 节点为端点能够剩下的最长路径。考虑从叶子节点往上合并贪心,那么如果能够合并出包含 \(u\) 节点的大于等于 \(k\) 的路径,那么就合并, \(f_u=0\);
否则 \(f_u=\max(f_v+1)\)。
如果每个 \(k\in [1,n]\) 都跑一次,那么复杂度是 \(O(n^2)\)。考虑优化。
我们发现,随着 \(k\) 增大,答案是一定不增的,并且假如当前枚举到 \(k\),那么答案最大为 \(\lfloor\frac{n}{k}\rfloor\)。这样除法的形式可以想到根号分治。一般的,如果 \(k\ge \sqrt{n}\),那么答案 \(\lfloor\frac{n}{k}\rfloor\le\sqrt{n}\),也就是后面的答案取值只有 \(\sqrt{n}\) 种。
结合上答案的单调性,不难想到二分每个段的右端点,这样复杂度为 \(O(n\sqrt{n}\log n)\)。
而对于 \(k\le\sqrt{n}\) 的部分,复杂度为 \(O(n\sqrt{n})\)。我们可以找到一个更好的值使得这两部分尽可能均匀。
设这个值为 \(B\),那么前面的复杂度为 \(O(nB)\),后面的复杂度为 \(O(\frac{n^2}{B}\log n)\),两部分相等,得 \(B=\sqrt{n\log n}\)。
所以总复杂度为 \(O(n\sqrt{n\log n})\)。但是这题很卡常,所以需要一些卡常技巧:
- 链式前向星存边
- 快读快写
- 一遍 dfs 后将树拍成 dfs 序后倒着遍历,这也是最关键的卡常的地方。
#include <bits/stdc++.h>
#define pii std::pair<int, int>
#define mk std::make_pair
#define fi first
#define se second
#define pb push_back
using i64 = long long;
using ull = unsigned long long;
const i64 iinf = 0x3f3f3f3f, linf = 0x3f3f3f3f3f3f3f3f;
const int N = 1e5 + 10;
int n, ans, k, g, tot, cnt;
int dfn[N], fat[N], mx[N], dmx[N];
struct node {
int to, nxt;
} e[N << 1];
int f[N], h[N];
void add(int u, int v) {
e[++cnt].to = v;
e[cnt].nxt = h[u];
h[u] = cnt;
}
void dfs(int u, int fa) {
dfn[++tot] = u;
fat[u] = fa;
for (int i = h[u]; i; i = e[i].nxt) {
int v = e[i].to;
if(v == fa) continue;
dfs(v, u);
}
}
void dfs2() {
for(int i = tot; i >= 1; i--) {
int u = dfn[i];
if(f[mx[u]] + f[dmx[u]] + 1 >= k) ans++, f[u] = 0;
else f[u] = f[mx[u]] + 1;
int fa = fat[u];
if(f[mx[fa]] < f[u]) dmx[fa] = mx[fa], mx[fa] = u;
else if(f[dmx[fa]] < f[u]) dmx[fa] = u;
}
}
bool check(int x) {
for(int i = 1; i <= n; i++) f[i] = mx[i] = dmx[i] = 0;
k = x, ans = 0;
dfs2();
if(ans == g) return 1;
return 0;
}
int main() {
scanf("%d", &n);
for (int i = 1; i < n; i++) {
int u, v;
scanf("%d%d", &u, &v);
add(u, v), add(v, u);
}
dfs(1, 0);
int n1 = n, m = 0;
while(n1) n1 >>= 1, m++;
m++;
m = sqrt(1LL * n * m);
for (int i = 1; i <= n; i++) {
k = i;
if(i <= m) {
for (int j = 1; j <= n; j++) f[j] = mx[j] = dmx[j] = 0;
dfs2();
printf("%d\n", ans);
} else {
for (int j = 1; j <= n; j++) f[j] = mx[j] = dmx[j] = 0;
dfs2(); g = ans;
int l = i, r = n, mx = i;
while(l <= r) {
int mid = (l + r) >> 1;
if(check(mid)) l = mid + 1, mx = mid;
else r = mid - 1;
}
for (int j = i; j <= mx; j++) printf("%d\n", g);
i = mx;
}
ans = 0;
}
return 0;
}