洛谷 P2590 [ZJOI2008]树的统计

洛谷P2590 [ZJOI2008]树的统计

原题链接:https://www.luogu.com.cn/problem/P2590

Solution

树链剖分

算是一道树链剖分的模板题。如果还不会树链剖分,可以先看我的博客《浅谈树链剖分》。

题目要求支持三种操作:单点修改点权、查询路径上的最大点权、查询路径上的点权和。

因此线段树的每个节点需要维护两个值:区间和与区间最大值。

这道题没什么思维难度,坑点也不多。唯一要注意的是点权可能为负,所以求最大值时答案的初值要设为 \(-INF\) 而不是 \(0\)。比较麻烦的是,线段树上的 \(query\) 和链上的查询都得按"求最大值"和"求和"各写一份 QWQ

直接看代码吧

完整代码

#include <iostream>
#include <cstdio>
#include <algorithm>
#include <cstring>

// Capacity of every per-vertex array (vertices are numbered from 1).
// Presumably the problem's n limit (30000) plus slack — confirm against the statement.
#define N 30010
// Large sentinel; -INF seeds maximum queries so negative answers are not masked.
#define INF 0x3f3f3f3f
// Left / right child of segment-tree node `rt` (heap layout, root = 1).
// NOTE(review): these expand textually, are not parenthesised, and require a
// variable named `rt` at the use site; they are only used as indices/arguments.
#define ls rt << 1
#define rs rt << 1 | 1

using namespace std;

// Adjacency list as a "chained forward star": edge[i].v is the target of edge i,
// edge[i].nxt the previously added edge from the same source (0 ends the list).
// main() stores each undirected edge twice, hence the N << 1 capacity.
struct node{
	int v, nxt;
}edge[N << 1];
int head[N], tot;   // head[x]: newest edge leaving x (0 = none); tot: edges stored so far
int n, m;           // number of vertices; number of operations
int w[N];           // w[x]: initial weight of vertex x (later changes go only to the segment tree)
// Filled by dfs1: subtree size, parent (0 for the root), heavy child (0 for a leaf),
// depth (the root, vertex 1, gets depth 1 because dep[0] == 0).
int siz[N], fa[N], son[N], dep[N];
// Filled by dfs2: head of each vertex's heavy chain, weight stored at each DFS
// position, DFS position of each vertex, and the running position counter.
int top[N], tw[N], id[N], cnt;
char s[10];         // operation keyword buffer filled by scanf in main

/**
 * Reads one signed decimal integer from stdin via getchar().
 *
 * Every non-digit before the number is skipped; a '-' seen while skipping makes
 * the result negative. Digits are then accumulated up to the first non-digit,
 * which is consumed.
 *
 * @return the parsed integer, or 0 if EOF is reached before any digit.
 *         (The previous version stored getchar() in a char and never tested
 *         for EOF, so truncated input made the skip loop spin forever.)
 */
inline int read(){
	int x = 0, f = 1;
	int ch = getchar();	// int, not char: keeps EOF distinguishable from a real byte
	while(ch != EOF && (ch < '0' || ch > '9')){
		if(ch == '-') f = -1;
		ch = getchar();
	}
	while(ch >= '0' && ch <= '9'){
		x = x * 10 + (ch - '0');
		ch = getchar();
	}
	return x * f;
}

/**
 * Appends the directed edge x -> y to x's adjacency list ("chained forward star":
 * head[x] is the newest edge index, edge[i].nxt links to the previous one).
 * Edge indices start at 1 so that 0 can terminate the list.
 */
inline void add(int x, int y){
	// Standard aggregate initialisation. The former `(node){...}` was a C99
	// compound literal, which C++ only accepts as a compiler extension.
	edge[++tot] = node{y, head[x]};
	head[x] = tot;
}

/**
 * First pass of heavy-light decomposition over the subtree of x (entered from f).
 * Records parent, depth and subtree size for every vertex, and picks the heavy
 * child: the child with the strictly largest subtree (on ties the earliest such
 * child in adjacency order is kept).
 */
void dfs1(int x, int f){
	fa[x] = f;
	dep[x] = dep[f] + 1;
	siz[x] = 1;
	for(int e = head[x]; e != 0; e = edge[e].nxt){
		const int child = edge[e].v;
		if(child != f){
			dfs1(child, x);
			siz[x] += siz[child];
			const int heavy = son[x];
			if(heavy == 0 || siz[heavy] < siz[child])
				son[x] = child;
		}
	}
}

/**
 * Second pass of heavy-light decomposition.
 * Gives x the next DFS position, remembers its chain head, and copies its weight
 * into tw at that position. The heavy child is visited first so that every heavy
 * chain occupies consecutive positions; each light child starts a new chain.
 */
void dfs2(int x, int topfa){
	top[x] = topfa;
	id[x] = ++cnt;
	tw[id[x]] = w[x];
	const int heavy = son[x];
	if(heavy == 0) return;	// leaf: no children to visit
	dfs2(heavy, topfa);
	for(int e = head[x]; e; e = edge[e].nxt){
		const int child = edge[e].v;
		if(child != fa[x] && child != heavy)
			dfs2(child, child);
	}
}

// Segment tree over DFS positions 1..n (heap-indexed, 4x capacity): per-node max and sum.
int maxs[N << 2], sum[N << 2];

/** Recomputes node rt's max and sum from its children 2*rt and 2*rt+1. */
void pushup(int rt){
	const int lc = rt << 1, rc = rt << 1 | 1;
	sum[rt] = sum[lc] + sum[rc];
	maxs[rt] = maxs[lc] > maxs[rc] ? maxs[lc] : maxs[rc];
}

/**
 * Builds segment-tree node rt covering positions [l, r] from tw[].
 * A leaf stores its single weight as both its max and its sum.
 */
void build(int l, int r, int rt){
	if(l != r){
		const int mid = (l + r) >> 1;
		build(l, mid, rt << 1);
		build(mid + 1, r, rt << 1 | 1);
		pushup(rt);
	}else{
		sum[rt] = tw[l];
		maxs[rt] = tw[l];
	}
}

/**
 * Point assignment: sets position x to k inside node rt (covering [l, r]),
 * then refreshes max/sum on every ancestor on the way back up.
 */
void update(int x, int k, int l, int r, int rt){
	if(l == r){
		sum[rt] = maxs[rt] = k;
		return;
	}
	const int mid = (l + r) >> 1;
	if(x > mid) update(x, k, mid + 1, r, rt << 1 | 1);
	else update(x, k, l, mid, rt << 1);
	pushup(rt);
}

/**
 * Maximum over positions [L, R] within node rt (covering [l, r]).
 * A child that does not overlap [L, R] contributes -INF, the identity for max,
 * so negative weights are reported correctly.
 */
int query_max(int L, int R, int l, int r, int rt){
	if(l >= L && r <= R) return maxs[rt];	// node fully inside the query range
	const int mid = (l + r) >> 1;
	const int leftBest = (L <= mid) ? query_max(L, R, l, mid, rt << 1) : -INF;
	const int rightBest = (R > mid) ? query_max(L, R, mid + 1, r, rt << 1 | 1) : -INF;
	return max(leftBest, rightBest);
}

/**
 * Sum over positions [L, R] within node rt (covering [l, r]).
 * A child that does not overlap [L, R] contributes 0.
 */
int query_sum(int L, int R, int l, int r, int rt){
	if(l >= L && r <= R) return sum[rt];	// node fully inside the query range
	const int mid = (l + r) >> 1;
	const int leftPart = (L <= mid) ? query_sum(L, R, l, mid, rt << 1) : 0;
	const int rightPart = (R > mid) ? query_sum(L, R, mid + 1, r, rt << 1 | 1) : 0;
	return leftPart + rightPart;
}

/**
 * Maximum vertex weight on the tree path between x and y (endpoints included).
 * While x and y are on different heavy chains, the endpoint whose chain head is
 * deeper has its whole chain segment queried and then jumps above that head.
 * Once both share a chain, the contiguous segment between them is queried.
 */
int q_max(int x, int y){
	int best = -INF;
	for(; top[x] != top[y]; x = fa[top[x]]){
		if(dep[top[x]] < dep[top[y]]) swap(x, y);	// make x the endpoint with the deeper chain head
		best = max(best, query_max(id[top[x]], id[x], 1, n, 1));
	}
	if(dep[x] > dep[y]) swap(x, y);	// shallower vertex has the smaller DFS position on a chain
	return max(best, query_max(id[x], id[y], 1, n, 1));
}

/**
 * Sum of vertex weights on the tree path between x and y (endpoints included).
 * Climbs chain by chain from the endpoint with the deeper chain head until both
 * vertices share a heavy chain, then adds the remaining contiguous segment.
 */
int q_sum(int x, int y){
	int total = 0;
	while(top[x] != top[y]){
		if(dep[top[y]] > dep[top[x]]) swap(x, y);	// x now has the deeper chain head
		total += query_sum(id[top[x]], id[x], 1, n, 1);
		x = fa[top[x]];
	}
	if(dep[y] < dep[x]) swap(x, y);	// order the final segment by DFS position
	return total + query_sum(id[x], id[y], 1, n, 1);
}

/**
 * Reads the tree and its weights, builds the decomposition rooted at vertex 1,
 * then processes m operations. Each operation is a keyword plus two integers x y,
 * told apart by the keyword's second letter (presumably QMAX / QSUM / CHANGE per
 * the problem statement):
 *   'M' -> print the maximum weight on path x..y
 *   'S' -> print the sum of weights on path x..y
 *   else -> set the weight of vertex x to y
 */
int main(){
	n = read();
	for(int i = 1; i < n; i++){
		int u, v;
		u = read(), v = read();
		add(u, v), add(v, u);	// undirected edge stored in both directions
	}
	for(int i = 1; i <= n; i++)
		w[i] = read();
	dfs1(1, 0);
	dfs2(1, 1);
	build(1, n, 1);
	m = read();
	while(m--){
		int x, y;
		// %9s bounds the write to s[10] (9 chars + NUL); a bare %s could overflow it.
		// On EOF/failure stop instead of re-running a stale keyword.
		if(scanf("%9s", s) != 1) break;
		x = read(), y = read();
		if(s[1] == 'M') printf("%d\n", q_max(x, y));
		else if(s[1] == 'S') printf("%d\n", q_sum(x, y));
		else update(id[x], y, 1, n, 1);	// point assignment at x's DFS position
	}
	return 0;
}
posted @ 2021-08-06 15:20  xixike  阅读(34)  评论(0)  编辑  收藏  举报