[ZJOI2007] 捉迷藏
在大佬的 blog 上看到的树上距离问题转化为括号序列求解的技巧。
就是在 dfs 的时候,某个节点入栈时加入左括号,出栈时加入右括号。
那么对于树上两点的距离,就是他们中间未匹配的括号数量。因为匹配的括号必定不存在于他们之间的路径上,其他的都存在于他们的路径上。
然后这道题我们维护树上两个黑点之间未匹配的括号数的最大值。
然后用线段树进行维护即可。
#include <bits/stdc++.h>
#define reg register
#define ll long long
using namespace std;
const int MAXN = 5e5 + 10;
const int INF = 0x3f3f3f3f;
int n, pos[MAXN], dfn[MAXN], clo = 0;
int E, head[MAXN], nxt[MAXN << 1], pnt[MAXN << 1];
inline void clear() {
E = 0;
memset(head, -1, sizeof(head));
}
inline void addedge(int x, int y) {
nxt[E] = head[x];
pnt[E] = y;
head[x] = E++;
}
inline void dfs(int u, int f) {
dfn[++clo] = -2;
dfn[pos[u] = ++clo] = u;
for(reg int i = head[u]; i != -1; i = nxt[i]) {
int v = pnt[i];
if(v != f)
dfs(v, u);
}
dfn[++clo] = -1;
}
int a[MAXN << 2], b[MAXN << 2], lp[MAXN << 2], rp[MAXN << 2], lm[MAXN << 2];
int rm[MAXN << 2], ret[MAXN << 2];
/*
a表示当前区间右括号数量,b表示当前区间左括号数量
rp表示当前区间内的一个黑点到它右端右括号和左括号加起来的最大值
rm表示当前区间内的一个黑点到它右端右括号比左括号多的数量的最大值
lp,lm相反
*/
inline void pushup(int x) {
a[x] = a[x << 1] + max(a[x << 1 | 1] - b[x << 1], 0);
b[x] = b[x << 1 | 1] + max(b[x << 1] - a[x << 1 | 1], 0);
rp[x] = max(rp[x << 1 | 1], max(rp[x << 1] - a[x << 1 | 1] + b[x << 1 | 1], rm[x << 1] + a[x << 1 | 1] + b[x << 1 | 1]));
rm[x] = max(rm[x << 1 | 1], rm[x << 1] + a[x << 1 | 1] - b[x << 1 | 1]);
lp[x] = max(lp[x << 1], max(lp[x << 1 | 1] + a[x << 1] - b[x << 1], lm[x << 1 | 1] + a[x << 1] + b[x << 1]));
lm[x] = max(lm[x << 1], lm[x << 1 | 1] - a[x << 1] + b[x << 1]);
ret[x] = max(max(ret[x << 1], ret[x << 1 | 1]), max(rp[x << 1] + lm[x << 1 | 1], rm[x << 1] + lp[x << 1 | 1]));
}
inline void build(int l, int r, int x) {
if(l == r) {
if(dfn[l] > 0)
lp[x] = rp[x] = lm[x] = rm[x] = ret[x] = 0;
else {
lp[x] = rp[x] = lm[x] = rm[x] = -INF;
ret[x] = -1;
}
if(dfn[l] == -2)
b[x] = 1;
if(dfn[l] == -1)
a[x] = 1;
return;
}
int mid = (l + r) >> 1;
build(l, mid, x << 1);
build(mid + 1, r, x << 1 | 1);
pushup(x);
}
inline void modify(int l, int r, int k, int x) {
if(l == r) {
if(lp[x] > -INF) {
lp[x] = rp[x] = lm[x] = rm[x] = -INF;
ret[x] = -1;
}
else
lp[x] = rp[x] = lm[x] = rm[x] = ret[x] = 0;
return;
}
int mid = (l + r) >> 1;
if(k <= mid)
modify(l, mid, k, x << 1);
if(k > mid)
modify(mid + 1, r, k, x << 1 | 1);
pushup(x);
}
inline void work() {
scanf("%d", &n);
clear();
for(reg int i = 1; i < n; ++i) {
int u, v;
scanf("%d%d", &u, &v);
addedge(u, v);
addedge(v, u);
}
dfs(1, -1);
build(1, clo, 1);
int m;
scanf("%d", &m);
for(reg int i = 1; i <= m; ++i) {
char op = getchar();
while(!isupper(op))
op = getchar();
if(op == 'C') {
int x;
scanf("%d", &x);
modify(1, clo, pos[x], 1);
}
else
printf("%d\n", ret[1]);
}
}
int main() {
work();
return 0;
}