例题:https://www.luogu.com.cn/problem/P3384

已知一棵包含 \(n\) 个结点的树(连通且无环),每个节点上包含一个数值,需要支持以下操作:
1 x y z:表示将树从 \(x\) 到 \(y\) 结点最短路径上所有节点的值都加上 \(z\)。
2 x y:表示求树从 \(x\) 到 \(y\) 结点最短路径上所有节点的值之和。
3 x z:表示将以 \(x\) 为根节点的子树内所有节点值都加上 \(z\)
4 x:表示求以 \(x\) 为根节点的子树内所有节点值之和

#include<bits/stdc++.h>
using namespace std;
using LL = long long;
// Heavy-light decomposition of a rooted tree, backed by a lazy range-add /
// range-sum segment tree whose sums are kept modulo `mod`.
//
// Usage order: fill a[1..n], call add() once per edge, then dfs1(root),
// dfs2(root, root) and build(1, 1, n) before any modify*/query* call.
// Nodes are 1-indexed; index 0 acts as the "no node" sentinel.
//
// Invariant: after build(), every stored sum and pending tag is a canonical
// residue in [0, mod), so negative weights and negative deltas are handled
// and every query result is non-negative.
struct HLD{
	// One segment-tree node covering positions [l, r] of the dfs2 order.
	struct node{
		int l, r;
		LL sum, add;  // sum: range sum mod `mod`; add: pending per-position delta
	};

	vector<vector<int>> e;  // adjacency lists
	// top[u]    : head of the heavy chain that contains u
	// dep[u]    : depth of u (the node dfs1 starts from gets depth 1)
	// parent[u] : parent of u (stays 0 for the node dfs1 starts from)
	// siz[u]    : number of nodes in the subtree of u
	// son[u]    : heavy child of u (0 when u has no child)
	// id[u]     : position of u in dfs2 order, in 1..n
	// a[u]      : initial weight of node u (filled by the caller)
	// val[p]    : initial weight of the node placed at position p
	vector<int> top, dep, parent, siz, son, id, a, val;
	int idx, mod;  // idx: last position handed out by dfs2; mod: modulus
	vector<node> tr;  // segment tree, root at index 1, children 2u and 2u+1

	// Allocates storage for a tree of n nodes whose sums are taken modulo P.
	HLD(int n, int P)
		: e(n + 1), top(n + 1), dep(n + 1), parent(n + 1), siz(n + 1),
		  son(n + 1), id(n + 1), a(n + 1), val(n + 1),
		  idx(0), mod(P), tr((n << 2) + 1) {}

	// Registers the undirected edge u - v.
	void add(int u, int v){
		e[u].push_back(v);
		e[v].push_back(u);
	}

	// First pass: computes dep, parent, siz and the heavy child son for the
	// subtree of u. Recursive; recursion depth equals the height of the tree.
	void dfs1(int u){
		siz[u] = 1;
		dep[u] = dep[parent[u]] + 1;
		for (int v : e[u]){
			if (v == parent[u]) continue;
			parent[v] = u;
			dfs1(v);
			siz[u] += siz[v];
			// siz[0] == 0, so the first child always becomes the candidate.
			if (siz[v] > siz[son[u]]) son[u] = v;
		}
	}

	// Second pass: assigns dfs-order positions, visiting the heavy child first
	// so that every heavy chain (and every subtree) is a contiguous range.
	// `up` is the head of the chain that u belongs to.
	void dfs2(int u, int up){
		id[u] = ++idx;
		top[u] = up;
		val[idx] = a[u];
		if (son[u]) dfs2(son[u], up);
		for (int v : e[u]){
			if (v == parent[u] || v == son[u]) continue;
			dfs2(v, v);  // a light child starts a new chain
		}
	}

	// Maps x to its canonical residue in [0, mod); needed because the built-in
	// % keeps the sign of the dividend.
	LL residue(LL x) const {
		x %= mod;
		return x < 0 ? x + mod : x;
	}

	// Adds k (already in [0, mod)) to every position covered by tree node u.
	void applyDelta(int u, LL k){
		node &t = tr[u];
		t.sum = (t.sum + k * (t.r - t.l + 1)) % mod;
		t.add = (t.add + k) % mod;
	}

	// Recomputes the sum of node u from its two children.
	void pushup(int u){
		tr[u].sum = (tr[u << 1].sum + tr[u << 1 | 1].sum) % mod;
	}

	// Hands the pending delta of node u down to both children.
	void pushdown(int u){
		LL pending = tr[u].add;
		if (pending == 0) return;  // nothing to propagate
		applyDelta(u << 1, pending);
		applyDelta(u << 1 | 1, pending);
		tr[u].add = 0;
	}

	// Builds tree node u over positions [l, r] from val[].
	void build(int u, int l, int r){
		tr[u] = {l, r, 0, 0};
		if (l == r){
			tr[u].sum = residue(val[l]);  // store leaves already reduced
			return;
		}
		int mid = (l + r) >> 1;
		build(u << 1, l, mid);
		build(u << 1 | 1, mid + 1, r);
		pushup(u);
	}

	// Adds k to every position in [l, r]; call with u = 1. k may be negative.
	void modify(int u, int l, int r, LL k){
		k = residue(k);
		if (l <= tr[u].l && tr[u].r <= r){
			applyDelta(u, k);
			return;
		}
		pushdown(u);
		int mid = (tr[u].l + tr[u].r) >> 1;
		if (l <= mid) modify(u << 1, l, r, k);
		if (r > mid) modify(u << 1 | 1, l, r, k);
		pushup(u);
	}

	// Returns the sum over positions [l, r], in [0, mod); call with u = 1.
	LL query(int u, int l, int r){
		if (l <= tr[u].l && tr[u].r <= r) return tr[u].sum;
		pushdown(u);
		int mid = (tr[u].l + tr[u].r) >> 1;
		LL sum = 0;
		if (l <= mid) sum = query(u << 1, l, r);
		if (r > mid) sum = (sum + query(u << 1 | 1, l, r)) % mod;
		return sum;
	}

	// Splits the tree path u - v into ranges [lo, hi] that are contiguous in
	// dfs2 order and calls visit(lo, hi) once per range.
	template<class F>
	void forEachPathSegment(int u, int v, F visit){
		while (top[u] != top[v]){
			// Always climb from the endpoint whose chain head is deeper.
			if (dep[top[u]] < dep[top[v]]) swap(u, v);
			visit(id[top[u]], id[u]);
			u = parent[top[u]];
		}
		// Both endpoints now share a chain; the shallower one comes first.
		if (dep[u] > dep[v]) swap(u, v);
		visit(id[u], id[v]);
	}

	// Adds k to every node on the path between u and v.
	void modifyRange(int u, int v, int k){
		LL delta = residue(k);
		forEachPathSegment(u, v, [&](int lo, int hi){
			modify(1, lo, hi, delta);
		});
	}

	// Returns the sum of all nodes on the path between u and v, in [0, mod).
	LL queryRange(LL u, LL v){
		LL ans = 0;
		forEachPathSegment(static_cast<int>(u), static_cast<int>(v),
			[&](int lo, int hi){
				ans = (ans + query(1, lo, hi)) % mod;
			});
		return ans;
	}

	// Adds k to every node in the subtree of u; the subtree occupies the
	// positions id[u] .. id[u] + siz[u] - 1.
	void modifySon(LL u, LL k){
		modify(1, id[u], id[u] + siz[u] - 1, k);
	}

	// Returns the sum of all nodes in the subtree of u, in [0, mod).
	LL querySon(LL u){
		return query(1, id[u], id[u] + siz[u] - 1);
	}
};
int main(){
	ios::sync_with_stdio(false);cin.tie(0);
	int n, m, r, p;
	cin >> n >> m >> r >> p;
	HLD t(n, p);
	for (int i = 1; i <= n; i ++ ){
		cin >> t.a[i];
	}
	for (int i = 0; i < n - 1; i ++ ){
		int u, v;
		cin >> u >> v;
		t.add(u, v);
	}
	t.dfs1(r);
	t.dfs2(r, r);
	t.build(1, 1, n);
	
	while (m -- ){
		int op, x, y, z;
		cin >> op >> x;
		if (op == 1){
			cin >> y >> z;
			t.modifyRange(x, y, z);
		}
		else if (op == 2){
			cin >> y;
			cout << t.queryRange(x, y) << "\n";
		}
		else if (op == 3){
			cin >> z;
			t.modifySon(x, z);
		}
		else{
			cout << t.querySon(x) << "\n";
		}
	}
	return 0;
}
posted on 2022-08-28 16:55  Hamine  阅读(52)  评论(0)  编辑  收藏  举报