树链剖分
推荐博客:https://www.cnblogs.com/ivanovcraft/p/9019090.html
前置知识:
dfs序,线段树
主要应用:树上有关问题的维护,将书上问题转化为序列问题从而用线段树进行统计维护
大概过程:
1,dfs1计算Size[x]数组(表示x这个树的大小),d数组(表示结点深度),son[x]数组(表示x的子结点中最大的结点),far[x]数组(表示x结点的父亲结点)
void dfs1(int u,int f,int dep)///dfs1指在处理d数组,son数组,far数组,Size数组 { d[u] = dep; far[u] = f; Size[u] = 1; son[u] = -1; for(int i = head[u]; i; i = Next[i]){ int v = ver[i]; if(v == f) continue; dfs1(v,u,dep+1); Size[u] += Size[v]; if(son[u] == -1 || Size[son[u]] < Size[v]) son[u] = v; } }
2,dfs2计算dfs序列,旨在处理好重链
void dfs2(int u,int T)///旨在处理重链,和dfs序列 { dfn[++ cnt] = u;id[u] = cnt; top[u] = T; if(son[u] == -1) return ; dfs2(son[u],T); for(int i = head[u]; i; i = Next[i]){ int v = ver[i]; if(v != son[u] && v != far[u]){ dfs2(v,v); } } }
通过这两个dfs,我们就可以将树上的每个结点转化为序列上的结点了。
例题:https://loj.ac/problem/10138
不知道出于什么想法,一开始一直觉得不用建树,然后wa wawawawa
#include "stdio.h" #include "string.h" #include "algorithm" using namespace std; const int N = 60010, M = 70010; const int INF = -3001010; int n, q; int head[N], ver[N], Next[N], tot; ///树的结构存储 int val[N]; ///存储每个结点的信息 int d[N], son[N], far[N], Size[N]; ///结点的深度,重儿子,祖先 int a[N], maxx[N * 4], sum[N * 4]; ///线段树上的结点值,maxx,sum值 int dfn[N], top[N], id[N]; ///存储dfs序,top是条链的祖先,id是每个结点在dfn中序列的下标位置 int cnt; ///表示的是dfs序列的最后一个位置 void add(int x, int y) { ///添加树边 ver[++tot] = y; Next[tot] = head[x]; head[x] = tot; } void Build_Tree(int id, int l, int r) { if (l == r) { sum[id] = maxx[id] = val[dfn[l]]; return; } int mid = (l + r) >> 1; Build_Tree(id * 2, l, mid); Build_Tree(id * 2 + 1, mid + 1, r); sum[id] = sum[id * 2] + sum[id * 2 + 1]; maxx[id] = max(maxx[id * 2], maxx[id * 2 + 1]); return; } void Update(int id, int l, int r, int loc, int x) ///将loc上的值进行更新 { if (l == r) { maxx[id] = x; sum[id] = x; return; } int mid = (l + r) >> 1; if (loc <= mid) Update(id * 2, l, mid, loc, x); else Update(id * 2 + 1, mid + 1, r, loc, x); maxx[id] = max(maxx[id * 2], maxx[id * 2 + 1]); sum[id] = 0; if (sum[id * 2] != INF) sum[id] += sum[id * 2]; if (sum[id * 2 + 1] != INF) sum[id] += sum[id * 2 + 1]; } int Query_sum(int id, int L, int R, int l, int r) ///查询[l,r]区间和 { if (L > r || R < l) return 0; if (l <= L && r >= R) { return sum[id]; } int mid = (L + R) >> 1; int ans = Query_sum(id * 2, L, mid, l, r) + Query_sum(id * 2 + 1, mid + 1, R, l, r); return ans; } int Query_maxx(int id, int L, int R, int l, int r) ///查询区间[l,r]之间的最大值 { if (l <= L && r >= R) return maxx[id]; int mid = (L + R) >> 1; int ans = -3 * 10010; if (l <= mid) ans = max(ans, Query_maxx(id * 2, L, mid, l, r)); if (r > mid) ans = max(ans, Query_maxx(id * 2 + 1, mid + 1, R, l, r)); return ans; } void dfs1(int u, int f, int dep) /// dfs1指在处理d数组,son数组,far数组,Size数组 { d[u] = dep; far[u] = f; Size[u] = 1; son[u] = -1; for (int i = head[u]; i; i = Next[i]) { int v = ver[i]; if (v == f) continue; dfs1(v, u, dep + 1); Size[u] += Size[v]; if (son[u] == -1 || Size[son[u]] < Size[v]) son[u] = v; } } void dfs2(int u, int T) ///旨在处理重链,和dfs序列 { dfn[++cnt] = u; id[u] = cnt; top[u] = T; if (son[u] == -1) return; dfs2(son[u], T); for (int i = head[u]; i; i = Next[i]) { int v = ver[i]; if (v != son[u] && v != far[u]) { dfs2(v, v); } } } int main() { scanf("%d", &n); for (int i = 1; i < n; i++) { int x, y; scanf("%d%d", &x, &y); add(x, y); add(y, x); } cnt = 0; dfs1(1, 0, 1); dfs2(1, 1); for (int i = 1; i <= n; i++) { scanf("%d", &val[i]); } Build_Tree(1, 1, n); scanf("%d", &q); while (q--) { char str[15]; int u, v; scanf("%s%d%d", str, &u, &v); if (!strcmp(str, "QMAX")) ///求最大值 { int fu = top[u], fv = top[v]; int ans = -300100; while (fu != fv) { if (d[fu] >= d[fv]) { ans = max(ans, Query_maxx(1, 1, n, id[fu], id[u])); u = far[fu]; fu = top[u]; } else { ans = max(ans, Query_maxx(1, 1, n, id[fv], id[v])); v = far[fv]; fv = top[v]; } } if (id[u] <= id[v]) ans = max(ans, Query_maxx(1, 1, n, id[u], id[v])); else ans = max(ans, Query_maxx(1, 1, n, id[v], id[u])); printf("%d\n", ans); } if (!strcmp(str, "QSUM")) ///求和 { int fu = top[u], fv = top[v]; int ans = 0; while (fu != fv) { if (d[fu] >= d[fv]) { ans += Query_sum(1, 1, n, id[fu], id[u]); u = far[fu]; fu = top[u]; } else { ans += Query_sum(1, 1, n, id[fv], id[v]); v = far[fv]; fv = top[v]; } } if (id[u] < id[v]) ans += Query_sum(1, 1, cnt, id[u], id[v]); else ans += Query_sum(1, 1, cnt, id[v], id[u]); printf("%d\n", ans); } if (!strcmp(str, "CHANGE")) ///更新指定u位置上的值更新为v { Update(1, 1, cnt, id[u], v); } } } //