Loading

CF1039D You Are Given a Tree (树形 dp + 贪心 + 根号分治)

CF1039D You Are Given a Tree

树形 dp + 贪心 + 根号分治

题目是一个经典问题,可以用树形 dp 和贪心解决。设 \(f_u\) 表示以 \(u\) 节点为端点能够剩下的最长路径。考虑从叶子节点往上合并贪心,那么如果能够合并出包含 \(u\) 节点的大于等于 \(k\) 的路径,那么就合并, \(f_u=0\)

否则 \(f_u=\max(f_v+1)\)

如果每个 \(k\in [1,n]\) 都跑一次,那么复杂度是 \(O(n^2)\)。考虑优化。

我们发现,随着 \(k\) 增大,答案是一定不增的,并且假如当前枚举到 \(k\),那么答案最大为 \(\lfloor\frac{n}{k}\rfloor\)。这样除法的形式可以想到根号分治。一般的,如果 \(k\ge \sqrt{n}\),那么答案 \(\lfloor\frac{n}{k}\rfloor\le\sqrt{n}\),也就是后面的答案取值只有 \(\sqrt{n}\)

结合上答案的单调性,不难想到二分每个段的右端点,这样复杂度为 \(O(n\sqrt{n}\log n)\)

而对于 \(k\le\sqrt{n}\) 的部分,复杂度为 \(O(n\sqrt{n})\)。我们可以找到一个更好的值使得这两部分尽可能均匀。

设这个值为 \(B\),那么前面的复杂度为 \(O(nB)\),后面的复杂度为 \(O(\frac{n^2}{B}\log n)\),两部分相等,得 \(B=\sqrt{n\log n}\)

所以总复杂度为 \(O(n\sqrt{n\log n})\)。但是这题很卡常,所以需要一些卡常技巧:

  1. 链式前向星存边
  2. 快读快写
  3. 一遍 dfs 后将树拍成 dfs 序后倒着遍历,这也是最关键的卡常的地方。
#include <bits/stdc++.h>
#define pii std::pair<int, int>
#define mk std::make_pair
#define fi first
#define se second
#define pb push_back

using i64 = long long;
using ull = unsigned long long;
const i64 iinf = 0x3f3f3f3f, linf = 0x3f3f3f3f3f3f3f3f;
const int N = 1e5 + 10;
int n, ans, k, g, tot, cnt;
int dfn[N], fat[N], mx[N], dmx[N];
struct node {
	int to, nxt;
} e[N << 1];
int f[N], h[N];
void add(int u, int v) {
	e[++cnt].to = v;
	e[cnt].nxt = h[u];
	h[u] = cnt;
}
void dfs(int u, int fa) { 
	dfn[++tot] = u; 
	fat[u] = fa; 
	for (int i = h[u]; i; i = e[i].nxt) {
		int v = e[i].to;
		if(v == fa) continue;
		dfs(v, u);
	}
}
void dfs2() {
	for(int i = tot; i >= 1; i--) { 
		int u = dfn[i];
		if(f[mx[u]] + f[dmx[u]] + 1 >= k) ans++, f[u] = 0;
		else f[u] = f[mx[u]] + 1;

		int fa = fat[u];
		if(f[mx[fa]] < f[u]) dmx[fa] = mx[fa], mx[fa] = u;
		else if(f[dmx[fa]] < f[u]) dmx[fa] = u; 
	}
}
bool check(int x) {
	for(int i = 1; i <= n; i++) f[i] = mx[i] = dmx[i] = 0;
	k = x, ans = 0;
	dfs2();
	if(ans == g) return 1;
	return 0;
}
int main() {
	scanf("%d", &n);
	for (int i = 1; i < n; i++) {
		int u, v;
		scanf("%d%d", &u, &v);
		add(u, v), add(v, u);
	}	

	dfs(1, 0);

	int n1 = n, m = 0;
	while(n1) n1 >>= 1, m++;
	m++;
	m = sqrt(1LL * n * m);

	for (int i = 1; i <= n; i++) {
		k = i;
		if(i <= m) {
			for (int j = 1; j <= n; j++) f[j] = mx[j] = dmx[j] = 0;
			dfs2();
			printf("%d\n", ans);
		} else {
			for (int j = 1; j <= n; j++) f[j] = mx[j] = dmx[j] = 0;
			dfs2(); g = ans;
			int l = i, r = n, mx = i;
			while(l <= r) {
				int mid = (l + r) >> 1;
				if(check(mid)) l = mid + 1, mx = mid;
				else r = mid - 1;
			}
			for (int j = i; j <= mx; j++) printf("%d\n", g);
			i = mx;
		}
		ans = 0;
	}

	return 0;
}
posted @ 2024-07-04 10:17  Fire_Raku  阅读(24)  评论(0编辑  收藏  举报