洛谷 P2146 [NOI2015] 软件包管理器(树链剖分)
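思路:对以 0 号软件包为根的依赖树做树链剖分,再用支持“区间赋值 + 区间求和”的线段树维护每个软件包的安装状态(1 表示已安装,0 表示未安装)。install x 时,答案是根到 x 路径上未安装的结点数(路径长度减去路径和),随后把整条路径赋值为 1;uninstall x 时,答案是 x 子树中已安装的结点数,随后把整棵子树赋值为 0(子树在 DFS 序中是连续区间 [dfn[x], dfn[x] + sz[x] - 1])。这是树链剖分的经典题目。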
代码:
#include <cstdio>

// Luogu P2146 ([NOI2015] package manager).
// Heavy-light decomposition of the dependency tree (root = package 0) over a
// segment tree supporting range assignment and range sum. Position dfn[x] of
// the segment tree holds 1 if package x is installed and 0 otherwise.

const int maxn = 100010;

// Children of segment-tree node o in the implicit heap layout (root = 1).
// (Were macros in the original; functions avoid macro pitfalls.)
inline int ls(int x) { return x << 1; }
inline int rs(int x) { return (x << 1) | 1; }

// Adjacency list; every undirected edge is stored twice, so 2 * maxn slots.
int head[maxn], Next[maxn * 2], ver[maxn * 2];
// sz:  subtree size
// son: heavy child; 0 means "none" (safe because root 0 is never anyone's child)
// d:   depth (root has depth 0)
// dfn: 1-based DFS order with heavy child visited first
// top: head of the heavy chain containing the node
// f:   parent, -1 for the root
int sz[maxn], son[maxn], d[maxn], dfn[maxn], top[maxn], f[maxn];
int tot, cnt;
int n;

struct SegmentTree {
    int val;  // sum of states over [l, r]
    int lz;   // pending "assign every position in [l, r] to lz"; -1 = none
    int l, r;
};
SegmentTree tr[maxn * 4];

// Appends directed edge x -> y to the adjacency list.
void add(int x, int y) {
    ver[++tot] = y;
    Next[tot] = head[x];
    head[x] = tot;
}

// First DFS: computes parent, depth, subtree size and heavy child.
// NOTE(review): recursion depth equals tree height (up to n on a chain);
// fine with a typical 8 MB Linux stack, may overflow on small stacks.
void dfs1(int x, int fa = -1) {
    sz[x] = 1;
    f[x] = fa;
    int mx = 0;
    for (int i = head[x]; i; i = Next[i]) {
        int y = ver[i];
        if (y == fa) continue;
        d[y] = d[x] + 1;
        dfs1(y, x);
        sz[x] += sz[y];
        if (sz[y] > mx) {
            mx = sz[y];
            son[x] = y;
        }
    }
}

// Second DFS: assigns DFS order (heavy child first, so each heavy chain is a
// contiguous dfn range) and chain heads. t is the head of x's chain.
void dfs2(int x, int fa, int t) {
    dfn[x] = ++cnt;
    top[x] = t;
    if (son[x]) dfs2(son[x], x, t);
    for (int i = head[x]; i; i = Next[i]) {
        int y = ver[i];
        if (y == fa || y == son[x]) continue;
        dfs2(y, x, y);  // a light child starts its own chain
    }
}

// Recomputes node o's sum from its children.
void pushup(int o) { tr[o].val = tr[ls(o)].val + tr[rs(o)].val; }

// Assigns val to every position covered by node o and records it as pending.
void maintain(int o, int val) {
    tr[o].val = val * (tr[o].r - tr[o].l + 1);
    tr[o].lz = val;
}

// Propagates a pending assignment of node o to its children.
void pushdown(int o) {
    if (tr[o].lz != -1) {
        maintain(ls(o), tr[o].lz);
        maintain(rs(o), tr[o].lz);
        tr[o].lz = -1;
    }
}

// Builds node o over [l, r] with every position 0 and no pending assignment.
void build(int o, int l, int r) {
    tr[o].l = l;
    tr[o].r = r;
    tr[o].val = 0;
    // Every node, not only leaves, must start with "no pending assignment";
    // leaving internal nodes at 0 would mean "pending assign 0".
    tr[o].lz = -1;
    if (l == r) return;
    int mid = (l + r) >> 1;
    build(ls(o), l, mid);
    build(rs(o), mid + 1, r);
    pushup(o);
}

// Assigns val to every position in [ql, qr]; node o covers [l, r].
void update(int o, int l, int r, int ql, int qr, int val) {
    if (l >= ql && r <= qr) {
        maintain(o, val);
        return;
    }
    pushdown(o);
    int mid = (l + r) >> 1;
    if (ql <= mid) update(ls(o), l, mid, ql, qr, val);
    if (qr > mid) update(rs(o), mid + 1, r, ql, qr, val);
    pushup(o);
}

// Returns the sum over [ql, qr]; node o covers [l, r].
int query(int o, int l, int r, int ql, int qr) {
    if (l >= ql && r <= qr) return tr[o].val;
    pushdown(o);
    int mid = (l + r) >> 1, ans = 0;
    if (ql <= mid) ans += query(ls(o), l, mid, ql, qr);
    if (qr > mid) ans += query(rs(o), mid + 1, r, ql, qr);
    return ans;
}

// Returns how many packages on the path root -> x are NOT installed:
// (number of nodes on the path) - (installed nodes on the path).
int solve(int x) {
    int ans = 0, st = x;
    while (x != -1) {  // f[root] == -1 ends the climb
        ans += query(1, 1, n, dfn[top[x]], dfn[x]);
        x = f[top[x]];
    }
    return d[st] - d[0] + 1 - ans;
}

// Assigns val to every node on the path root -> x, one heavy chain at a time.
void update1(int x, int val) {
    while (x != -1) {
        update(1, 1, n, dfn[top[x]], dfn[x], val);
        x = f[top[x]];
    }
}

char s[110];

int main() {
    int x, m;
    scanf("%d", &n);
    // Package i (1 <= i < n) depends on package x: x is i's parent.
    for (int i = 1; i < n; i++) {
        scanf("%d", &x);
        add(x, i);
        add(i, x);
    }
    f[0] = -1;
    build(1, 1, n);
    dfs1(0);
    dfs2(0, -1, 0);
    scanf("%d", &m);
    while (m--) {
        // Width limit keeps the read inside s + 1 (109 bytes incl. NUL).
        scanf("%108s", s + 1);
        scanf("%d", &x);
        int state = query(1, 1, n, dfn[x], dfn[x]);
        if (s[1] == 'i') {
            // install: if x is installed, so is its whole root path.
            if (state == 1) {
                printf("0\n");
                continue;
            }
            printf("%d\n", solve(x));
            update1(x, 1);
        } else {
            // uninstall: if x is not installed, neither is its subtree.
            if (state == 0) {
                printf("0\n");
                continue;
            }
            // Subtree of x occupies dfn range [dfn[x], dfn[x] + sz[x] - 1].
            printf("%d\n", query(1, 1, n, dfn[x], dfn[x] + sz[x] - 1));
            update(1, 1, n, dfn[x], dfn[x] + sz[x] - 1, 0);
        }
    }
    return 0;
}