洛谷P2590 [ZJOI2008]树的统计 题解 树链剖分+线段树
题目链接:https://www.luogu.org/problem/P2590
树链剖分模板题。
剖分过程要用到如下 7 个数组：
- fa[u]：u 的父节点编号；
- dep[u]：u 的深度；
- size[u]：以 u 为根的子树中的节点总数；
- son[u]：u 的重儿子（子树节点数最多的儿子）；
- top[u]：u 所在重链的顶端节点；
- seg[u]：u 在线段树中的位置；
- rev[x]：线段树中位置 x 对应的节点编号，即 seg 的逆映射，满足 rev[seg[u]] == u。
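然后套用线段树模板，实现区间最大值查询、区间和查询以及单点修改操作。查询路径 u–v 时，每次让所在重链链顶更深的那个端点处理自己这一段：把“链顶到该端点”这段连续区间交给线段树查询，再跳到链顶的父节点；当两个端点位于同一条重链上时，再查询它们之间的最后一段即可。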
实现代码如下:
#include <bits/stdc++.h>
using namespace std;
// "Minus infinity" sentinel used by the max queries (written as -INF there).
// A typed constant instead of a macro.
const int INF = 1 << 29;
// Array capacity for node ids.
const int maxn = 30030;

// Heavy-light decomposition data, indexed by node id:
int fa[maxn];    // fa[u]: parent of u
int dep[maxn];   // dep[u]: depth of u
// NOTE: with `using namespace std;` in effect, an unqualified `size` is
// ambiguous with std::size under C++17; refer to this array as `::size`.
int size[maxn];  // size[u]: number of nodes in the subtree rooted at u
int son[maxn];   // son[u]: heavy child of u (stays 0 when u has no child)
int top[maxn];   // top[u]: head node of the heavy chain containing u
int seg[maxn];   // seg[u]: position of u in the segment tree
int seg_cnt;     // last segment-tree position handed out by dfs2
int rev[maxn];   // rev[p]: node stored at segment-tree position p (rev[seg[u]] == u)

int n;           // number of nodes
int w[maxn];     // w[u]: weight of node u

// Segment tree over positions 1..n (4 * maxn slots).
int maxv[maxn << 2];  // maxv[rt]: maximum over node rt's interval
int sumv[maxn << 2];  // sumv[rt]: sum over node rt's interval

std::vector<int> g[maxn];  // adjacency lists of the tree
void dfs1(int u, int p) {
size[u] = 1;
for (vector<int>::iterator it = g[u].begin(); it != g[u].end(); it ++) {
int v = (*it);
if (v == p) continue;
fa[v] = u;
dep[v] = dep[u] + 1;
dfs1(v, u);
size[u] += size[v];
if (size[v] >size[son[u]]) son[u] = v;
}
}
void dfs2(int u, int tp) {
seg[u] = ++seg_cnt;
rev[seg_cnt] = u;
top[u] = tp;
if (son[u]) dfs2(son[u], tp);
for (vector<int>::iterator it = g[u].begin(); it != g[u].end(); it ++) {
int v = (*it);
if (v == fa[u] || v == son[u]) continue;
dfs2(v, v);
}
}
// Shorthand argument lists for the recursive segment-tree routines.
// lson expands to the (l, r, rt) triple of the left child: l, mid, rt*2.
// rson expands to the triple of the right child: mid+1, r, rt*2+1.
// Both rely on variables named l, r, mid and rt being in scope where they are used.
#define lson l, mid, rt<<1
#define rson mid+1, r, rt<<1|1
// Recomputes segment-tree node rt from its two children (indices 2*rt and
// 2*rt+1): the sum is added and the maximum is taken.
void push_up(int rt) {
    const int lc = rt << 1;
    const int rc = lc | 1;
    sumv[rt] = sumv[lc] + sumv[rc];
    maxv[rt] = maxv[lc] < maxv[rc] ? maxv[rc] : maxv[lc];
}
// Builds the segment tree node rt covering positions [l, r].
// The leaf at position p stores w[rev[p]], the weight of the node placed at p
// by dfs2, so rev must be filled before this is called.
void build(int l, int r, int rt) {
    if (l != r) {
        const int mid = (l + r) / 2;
        build(l, mid, rt << 1);
        build(mid + 1, r, rt << 1 | 1);
        push_up(rt);
        return;
    }
    sumv[rt] = w[rev[l]];
    maxv[rt] = sumv[rt];
}
// Point assignment: sets the leaf at position p to value v, then refreshes
// every ancestor on the way back up. Node rt covers positions [l, r].
void update(int p, int v, int l, int r, int rt) {
    if (l != r) {
        const int mid = (l + r) / 2;
        if (mid < p) update(p, v, mid + 1, r, rt << 1 | 1);
        else         update(p, v, l, mid, rt << 1);
        push_up(rt);
    } else {
        maxv[rt] = v;
        sumv[rt] = v;
    }
}
// Maximum over positions [L, R], where segment-tree node rt covers [l, r].
// A fully covered node answers directly. Otherwise recurse only into the
// children that overlap [L, R]; the result is never below -INF.
int query_max(int L, int R, int l, int r, int rt) {
    if (L <= l && r <= R) return maxv[rt];
    const int mid = (l + r) / 2;
    int best = -INF;
    if (L <= mid) {
        const int left = query_max(L, R, l, mid, rt << 1);
        if (best < left) best = left;
    }
    if (mid < R) {
        const int right = query_max(L, R, mid + 1, r, rt << 1 | 1);
        if (best < right) best = right;
    }
    return best;
}
// Sum over positions [L, R], where segment-tree node rt covers [l, r].
// A child that does not overlap [L, R] contributes 0.
int query_sum(int L, int R, int l, int r, int rt) {
    if (L <= l && r <= R) return sumv[rt];
    const int mid = (l + r) / 2;
    const int from_left  = (L <= mid) ? query_sum(L, R, l, mid, rt << 1) : 0;
    const int from_right = (mid < R) ? query_sum(L, R, mid + 1, r, rt << 1 | 1) : 0;
    return from_left + from_right;
}
// Maximum node weight on the tree path between u and v, both endpoints included.
// While the endpoints lie on different heavy chains, take the endpoint whose
// chain head is deeper, query the contiguous range from that head down to the
// endpoint, and jump to the head's parent. Once both are on one chain, query
// the final range between them.
int ask_max(int u, int v) {
    int best = -INF;
    while (top[u] != top[v]) {
        if (dep[top[v]] > dep[top[u]]) swap(u, v);
        const int head = top[u];
        const int part = query_max(seg[head], seg[u], 1, n, 1);
        if (best < part) best = part;
        u = fa[head];
    }
    // Same chain now: the shallower endpoint has the smaller position.
    const int shallow = dep[u] < dep[v] ? u : v;
    const int deep    = dep[u] < dep[v] ? v : u;
    const int tail = query_max(seg[shallow], seg[deep], 1, n, 1);
    return best < tail ? tail : best;
}
// Sum of node weights on the tree path between u and v, both endpoints included.
// Each step takes the endpoint whose chain head is deeper, adds the range from
// that head down to the endpoint, and jumps to the head's parent. When both
// endpoints share a chain, the remaining range between them is added.
int ask_sum(int u, int v) {
    int total = 0;
    while (top[u] != top[v]) {
        if (dep[top[v]] > dep[top[u]]) swap(u, v);
        const int head = top[u];
        total += query_sum(seg[head], seg[u], 1, n, 1);
        u = fa[head];
    }
    // Same chain now: the shallower endpoint has the smaller position.
    const int shallow = dep[u] < dep[v] ? u : v;
    const int deep    = dep[u] < dep[v] ? v : u;
    total += query_sum(seg[shallow], seg[deep], 1, n, 1);
    return total;
}
int m;
string s;
int main() {
cin >> n;
for (int i = 1; i < n; i ++) {
int u, v;
cin >> u >> v;
g[u].push_back(v);
g[v].push_back(u);
}
for (int i = 1; i <= n; i ++) cin >> w[i];
dep[1] = fa[1] = 1;
dfs1(1, -1);
dfs2(1, 1);
build(1, n, 1);
cin >> m;
while (m --) {
int u, v;
cin >> s >> u >> v;
if (s == "CHANGE") update(seg[u], v, 1, n, 1);
else if (s == "QMAX") cout << ask_max(u, v) << endl;
else cout << ask_sum(u, v) << endl;
}
return 0;
}