「CEOI2017」Mousetrap
方便起见,我们把陷阱点作为这棵树的根节点,那么老鼠进入陷阱的过程就是从某一个节点往上跳父亲的过程。
注意到老鼠要尽可能的拖延时间,那么它肯定想走到子树里去。
思考发现,老鼠要是进入一棵子树,那么最后一定会被卡在这个子树的一个叶子上面。
而且不难发现老鼠一定是先自己往根走几步,然后再找个子树钻进去。
不难发现如果老鼠被卡在叶子上,最好的做法就是把这个叶子到根的路径的分岔全部堵死,最后再把这条路径扫干净。
老鼠进入某个子树再被赶回来的最小步数是可以 \(\text{DP}\) 出来的:
\(dp_u = \text{2ndmax}\{dp_v\} + son_u\),其中 \(v\) 为 \(u\) 的儿子,\(son_u\) 为 \(u\) 的儿子数。
如果我们再用 \(f_u\) 来表示 \(u\) 到根的路径上的分岔(我们把 \(u\) 连向儿子的边也算进来),那么 \(f_u = f_{fa_u} + son_u - 1\)
继而我们可以得到,如果老鼠在 \(u\) 点钻入了 \(v\) 的子树,此时的最小步数就是 \(f_u + dp_v - [u \ne s]\) (\(s\) 表示老鼠的起点)
好像我们只要把起点到根的路径上每一个位置都试一次就可以了?其实不然。
因为在老鼠往上走的时候,我们并不知道我们能不能堵出我们想要的局面。
所以我们可以考虑二分一个最小步数 \(mid\) ,那么我们就必须把老鼠可能达成的最后答案值大于 \(mid\) 的路全都堵上,如果堵的过程中发现步数不够就不合法。
参考代码:
#include <cstdio>
int min(int a, int b) { return a < b ? a : b; }
int max(int a, int b) { return a > b ? a : b; }
template < class T > void read(T& s) {
s = 0; int f = 0; char c = getchar();
while ('0' > c || c > '9') f |= c == '-', c = getchar();
while ('0' <= c && c <= '9') s = s * 10 + c - 48, c = getchar();
s = f ? -s : s;
}
typedef long long LL;
const int _ = 1e6 + 5;
int tot, head[_]; struct Edge { int v, nxt; } edge[_ << 1];
void Add_edge(int u, int v) { edge[++tot] = (Edge) { v, head[u] }, head[u] = tot; }
int n, rt, s, dp[_], dis[_], fa[_];
void dfs(int u, int f) {
fa[u] = f;
int cnt = 0;
for (int i = head[u]; i; i = edge[i].nxt)
if (edge[i].v != f) ++cnt;
dis[u] = dis[f] - 1 + cnt;
int mx = 0, mmx = 0;
for (int i = head[u]; i; i = edge[i].nxt) {
int v = edge[i].v; if (v == f) continue ;
dfs(v, u);
if (dp[v] >= mx) mmx = mx, mx = dp[v];
else if (dp[v] >= mmx) mmx = dp[v];
}
dp[u] = mmx + cnt;
}
int check(int mid) {
int sum = 0, x = 1;
for (int i = s, pre = 0; i; ++x, pre = i, i = fa[i]) {
int tmp = 0;
for (int j = head[i]; j; j = edge[j].nxt) {
int v = edge[j].v; if (v == pre || v == fa[i]) continue ;
if (dp[v] + dis[i] + 1 - (i != s) > mid) ++tmp;
}
sum += tmp, mid -= tmp;
if (sum > x || mid < 0) return 0;
}
return 1;
}
int main() {
#ifndef ONLINE_JUDGE
freopen("cpp.in", "r", stdin), freopen("cpp.out", "w", stdout);
#endif
read(n), read(rt), read(s);
for (int u, v, i = 1; i < n; ++i)
read(u), read(v), Add_edge(u, v), Add_edge(v, u);
dfs(rt, 0);
int l = 0, r = n - 1;
while (l < r) {
int mid = (l + r) >> 1;
if (check(mid)) r = mid; else l = mid + 1;
}
printf("%d\n", l);
return 0;
}