Snow的追寻(线段树)(LCA)
Snow的追寻
题目大意
给你一棵树,每次规定两个子树不能到,问你树上的最长路径长度。
思路
看到有关子树,考虑用 dfs 序来搞。
而且一般这种子树的操作会用到线段树?
考虑用线段树维护,维护 \(l\sim r\) 区间的点能形成的最长路径。
这样子的话,我们可以把题目要求不能有两个子树里面的点得到剩下的点,用线段树拿出那三段,然后再用线段树的合并方法合并起来,最后得到的值就是答案。
然后考虑如何合并。
考虑线段树维护这一条路径的长度,以及两段的两个点。
那你合并的时候,就两条路径四个点,不难想到新的路径的两个端点一定是这四个之中的两个。
那我们可以直接暴力看没两个点的匹配情况(求两点之间路径长度用 LCA 求),然后选最大的那个。
然后就可以啦。
代码
#include<cstdio>
#include<iostream>
#include<algorithm>
using namespace std;
const int N = 100005;
struct node {
int to, nxt;
}e[N << 2];
int n, q, x, y, le[N], KK, up[N], ans, tmp;
int fa[N][21], deg[N], dfn[N], ed[N], dy[N];
void add(int x, int y) {
e[++KK] = (node){y, le[x]}; le[x] = KK;
e[++KK] = (node){x, le[y]}; le[y] = KK;
}
void dfs(int now, int father) {
deg[now] = deg[father] + 1;
fa[now][0] = father;
dfn[now] = ++tmp;
dy[tmp] = now;
for (int i = le[now]; i; i = e[i].nxt)
if (e[i].to != father) {
dfs(e[i].to, now);
}
ed[now] = tmp;
}
int LCA(int x, int y) {//LCA 求路径长度
if (deg[x] < deg[y]) swap(x, y);
for (int i = 20; i >= 0; i--)
if (deg[fa[x][i]] >= deg[y]) x = fa[x][i];
if (x == y) return x;
for (int i = 20; i >= 0; i--)
if (fa[x][i] != fa[y][i])
x = fa[x][i], y = fa[y][i];
return fa[x][0];
}
int get_dis(int x, int y) {
if (!x || !y) return 0;
int z = LCA(x, y);
return deg[x] + deg[y] - 2 * deg[z];
}
struct XDtree {//线段树
struct node {
int val, fir, sec;
}a[N << 2], ans;
void merge(node &x, node y, node z) {
int a, b, c, d, e;
a = get_dis(y.fir, z.fir);
b = get_dis(y.fir, z.sec);
c = get_dis(y.sec, z.fir);
d = get_dis(y.sec, z.sec);
e = max(max(a, b), max(c, d));
x.val = e;//四个点里面任选两个匹配得到最长路径
if (e == a) x.fir = y.fir, x.sec = z.fir;
if (e == b) x.fir = y.fir, x.sec = z.sec;
if (e == c) x.fir = y.sec, x.sec = z.fir;
if (e == d) x.fir = y.sec, x.sec = z.sec;
if (x.val < y.val) x.val = y.val, x.fir = y.fir, x.sec = y.sec;
if (x.val < z.val) x.val = z.val, x.fir = z.fir, x.sec = z.sec;
if (!x.val) x.fir = x.sec = 0;
}
void build(int now, int l, int r) {
if (l == r) {
a[now].fir = a[now].sec = dy[l];
a[now].val = 0; return ;
}
int mid = (l + r) >> 1;
build(now << 1, l, mid);
build(now << 1 | 1, mid + 1, r);
merge(a[now], a[now << 1], a[now << 1 | 1]);
}
void find(int now, int l, int r, int L, int R) {
if (L > R) return ;
if (L <= l && r <= R) {
merge(ans, ans, a[now]);
return ;
}
int mid = (l + r) >> 1;
if (L <= mid) find(now << 1, l, mid, L, R);
if (mid < R) find(now << 1 | 1, mid + 1, r, L, R);
}
}T;
int main() {
// freopen("snow.in", "r", stdin);
// freopen("snow.out", "w", stdout);
scanf("%d %d", &n, &q);
for (int i = 1; i < n; i++) {
scanf("%d %d", &x, &y);
add(x, y);
}
dfs(1, 0);
for (int i = 1; i <= 20; i++)
for (int j = 1; j <= n; j++)
fa[j][i] = fa[fa[j][i - 1]][i - 1];
T.build(1, 1, n);
while (q--) {
scanf("%d %d", &x, &y);
if (dfn[y] < dfn[x]) swap(x, y);
T.ans.val = T.ans.fir = T.ans.sec = 0;
T.find(1, 1, n, 1, dfn[x] - 1);//分成三段合并入答案
T.find(1, 1, n, ed[x] + 1, dfn[y] - 1);
if (ed[y] >= ed[x]) T.find(1, 1, n, ed[y] + 1, n);
else T.find(1, 1, n, ed[x] + 1, n);
printf("%d\n", T.ans.val);
}
fclose(stdin);
fclose(stdout);
return 0;
}