洛谷 P2590 [ZJOI2008]树的统计
洛谷P2590 [ZJOI2008]树的统计
Solution
树链剖分
算是一道板子题（模板题）。如果不会树链剖分，可以看我的博客《浅谈树链剖分》。
题目要求支持三种操作：单点修改、查询两点间路径上的最大值、查询两点间路径上的权值和。
因此线段树需要维护两个值：一个是区间和，另一个是区间最大值。
这道题没什么思维难度，也没什么坑点，只是线段树查询和路径查询这两个 \(query\) 函数都得各写两份（一份求最大值、一份求和），真的麻烦QWQ
直接看代码吧
完整代码
#include <iostream>
#include <cstdio>
#include <algorithm>
#include <cstring>
#define N 30010
#define INF 0x3f3f3f3f
#define ls rt << 1
#define rs rt << 1 | 1
using namespace std;
// Edge record for a "forward star" adjacency list: v is the edge's endpoint,
// nxt is the index of the next edge leaving the same vertex (0 ends the list).
struct node{
    int v;
    int nxt;
};
node edge[N << 1];   // each undirected tree edge is stored in both directions
int head[N];         // head[x]: index of the first edge leaving x (0 = none)
int tot;             // number of edges stored so far
int n;               // number of vertices
int m;               // number of queries
int w[N];            // w[x]: weight of vertex x
// Heavy-light decomposition state, filled by dfs1 / dfs2.
int siz[N];          // subtree size
int fa[N];           // parent vertex (0 for the root)
int son[N];          // heavy child (0 when the vertex has no children)
int dep[N];          // depth, root has depth 1
int top[N];          // top vertex of the heavy chain containing x
int tw[N];           // tw[p]: weight of the vertex placed at DFS position p
int id[N];           // id[x]: DFS position of vertex x
int cnt;             // DFS positions handed out so far
char s[10];          // operation-name token of the current query
// Reads the next (optionally negative) decimal integer from stdin.
// Any non-digit characters before the number are skipped; a '-' seen while
// skipping makes the result negative.
// Returns 0 if the input ends before a digit is found. The original version
// looped forever in that case, because EOF compares less than '0' and the
// skip loop never tested for it.
inline int read(){
    int x = 0, f = 1;
    int ch = getchar();  // int, not char, so EOF is distinguishable from data
    while(ch != EOF && (ch < '0' || ch > '9')){
        if(ch == '-') f = -1;
        ch = getchar();
    }
    while(ch >= '0' && ch <= '9'){
        x = x * 10 + (ch - '0');
        ch = getchar();
    }
    return x * f;
}
// Adds the directed edge x -> y by prepending it to x's adjacency list.
// Plain member assignment replaces the former `(node){y, head[x]}`, which is
// a C99 compound literal: accepted by GCC/Clang only as an extension in C++
// and rejected by strictly conforming compilers.
inline void add(int x, int y){
    ++tot;
    edge[tot].v = y;
    edge[tot].nxt = head[x];
    head[x] = tot;
}
// First HLD pass over the subtree rooted at x, whose parent is f.
// Records parent, depth and subtree size of each vertex. Also chooses the
// heavy child: the first child with the strictly largest subtree.
void dfs1(int x, int f){
    fa[x] = f;
    siz[x] = 1;
    dep[x] = 1 + dep[f];
    for(int e = head[x]; e != 0; e = edge[e].nxt){
        const int to = edge[e].v;
        if(to != f){
            dfs1(to, x);
            siz[x] += siz[to];
            const bool heavier = (son[x] == 0) || (siz[son[x]] < siz[to]);
            if(heavier) son[x] = to;
        }
    }
}
// Second HLD pass: assigns DFS positions and heavy-chain tops.
// x is placed at the next position and its weight is copied to tw[].
// The heavy child is visited first, so every heavy chain occupies a run of
// consecutive positions. Each light child starts a new chain with itself as top.
void dfs2(int x, int topfa){
    top[x] = topfa;
    cnt++;
    id[x] = cnt;
    tw[cnt] = w[x];
    if(son[x] == 0) return;  // no children
    dfs2(son[x], topfa);
    for(int e = head[x]; e; e = edge[e].nxt){
        const int to = edge[e].v;
        if(to != fa[x] && to != son[x])
            dfs2(to, to);
    }
}
// Segment tree over DFS positions; node rt has children rt*2 and rt*2+1.
int maxs[N << 2];  // maximum of the node's range
int sum[N << 2];   // sum of the node's range
// Recomputes node rt's aggregates from its two children.
void pushup(int rt){
    const int lc = rt << 1;
    const int rc = rt << 1 | 1;
    sum[rt] = sum[lc] + sum[rc];
    maxs[rt] = maxs[lc] < maxs[rc] ? maxs[rc] : maxs[lc];
}
// Builds segment tree node rt covering positions [l, r] from tw[].
void build(int l, int r, int rt){
    if(l != r){
        const int mid = (l + r) >> 1;
        build(l, mid, rt << 1);
        build(mid + 1, r, rt << 1 | 1);
        pushup(rt);
        return;
    }
    // Leaf: max and sum of a single position are both its value.
    sum[rt] = tw[l];
    maxs[rt] = tw[l];
}
// Sets position x to value k within node rt, which covers [l, r].
// Descends to the leaf, then recomputes every ancestor back up to rt.
void update(int x, int k, int l, int r, int rt){
    const int root = rt;
    while(l != r){
        const int mid = (l + r) >> 1;
        if(x <= mid){
            r = mid;
            rt = rt << 1;
        } else {
            l = mid + 1;
            rt = rt << 1 | 1;
        }
    }
    maxs[rt] = sum[rt] = k;
    // Climb back to the starting node, refreshing each ancestor on the way.
    while(rt != root){
        rt >>= 1;
        pushup(rt);
    }
}
// Maximum over positions [L, R], searched within node rt covering [l, r].
// Returns -INF for parts of the range that do not intersect [L, R].
int query_max(int L, int R, int l, int r, int rt){
    if(l >= L && R >= r) return maxs[rt];  // node range fully inside the query
    const int mid = (l + r) >> 1;
    int best = -INF;
    if(mid >= L){
        const int lv = query_max(L, R, l, mid, rt << 1);
        if(lv > best) best = lv;
    }
    if(mid < R){
        const int rv = query_max(L, R, mid + 1, r, rt << 1 | 1);
        if(rv > best) best = rv;
    }
    return best;
}
// Sum over positions [L, R], searched within node rt covering [l, r].
int query_sum(int L, int R, int l, int r, int rt){
    if(l >= L && R >= r) return sum[rt];  // node range fully inside the query
    const int mid = (l + r) >> 1;
    const int left = (L <= mid) ? query_sum(L, R, l, mid, rt << 1) : 0;
    const int right = (R > mid) ? query_sum(L, R, mid + 1, r, rt << 1 | 1) : 0;
    return left + right;
}
// Maximum vertex weight on the tree path between x and y, both ends included.
// While the endpoints are on different heavy chains, the endpoint whose chain
// top is deeper is lifted. Its chain segment (a contiguous DFS range) is
// queried first. Once both share a chain, the range between them is queried.
int q_max(int x, int y){
    int best = -INF;
    for(; top[x] != top[y]; x = fa[top[x]]){
        if(dep[top[y]] > dep[top[x]]) swap(x, y);
        best = max(best, query_max(id[top[x]], id[x], 1, n, 1));
    }
    const bool xDeeper = dep[x] > dep[y];
    const int upper = xDeeper ? y : x;
    const int lower = xDeeper ? x : y;
    return max(best, query_max(id[upper], id[lower], 1, n, 1));
}
// Sum of vertex weights on the tree path between x and y, both ends included.
// Walks the path chain by chain, exactly like q_max, accumulating range sums.
int q_sum(int x, int y){
    int total = 0;
    for(; top[x] != top[y]; x = fa[top[x]]){
        if(dep[top[y]] > dep[top[x]]) swap(x, y);
        total += query_sum(id[top[x]], id[x], 1, n, 1);
    }
    const bool xDeeper = dep[x] > dep[y];
    return total + query_sum(id[xDeeper ? y : x], id[xDeeper ? x : y], 1, n, 1);
}
int main(){
n = read();
for(int i = 1; i < n; i++){
int u, v;
u = read(), v = read();
add(u, v), add(v, u);
}
for(int i = 1; i <= n; i++)
w[i] = read();
dfs1(1, 0);
dfs2(1, 1);
build(1, n, 1);
m = read();
while(m--){
int x, y;
scanf("%s", s);
x = read(), y = read();
if(s[1] == 'M') printf("%d\n", q_max(x, y));
else if(s[1] == 'S') printf("%d\n", q_sum(x, y));
else update(id[x], y, 1, n, 1);
}
return 0;
}