例题:https://www.luogu.com.cn/problem/P3384
已知一棵包含 \(n\) 个结点的树(连通且无环),根结点为 \(r\),每个结点上包含一个数值,需要支持以下操作(所有求和结果对 \(P\) 取模后输出):
1 x y z:表示将树从 \(x\) 到 \(y\) 结点最短路径上所有节点的值都加上 \(z\)。
2 x y:表示求树从 \(x\) 到 \(y\) 结点最短路径上所有节点的值之和。
3 x z:表示将以 \(x\) 为根节点的子树内所有节点值都加上 \(z\)。
4 x:表示求以 \(x\) 为根节点的子树内所有节点值之和
#include<bits/stdc++.h>
using namespace std;
using LL = long long;
// Heavy-light decomposition over a tree whose nodes are numbered 1..n, backed by a
// lazy-propagation segment tree over the DFS order. Supports adding a value to, and
// summing, the nodes on a path or in a subtree; all arithmetic is done modulo `mod`.
//
// Expected call sequence: construct, fill a[1..n], add() every edge, then
// dfs1(root), dfs2(root, root) and build(1, 1, n) before any modify/query call.
// Stored sums and lazy tags are always kept in [0, mod), so negative initial values
// or negative increments still produce non-negative results.
struct HLD{
    // Segment-tree node covering DFS positions [l, r].
    struct node{
        int l, r;
        LL sum, add;    // range sum (mod `mod`) and pending add not yet pushed to children
    };

    vector<vector<int>> e;  // adjacency lists, indexed by node
    // top[u]: head of u's heavy chain; dep[u]: depth (root = 1 because dep[0] = 0);
    // parent[u]: parent of u (0 for the root); siz[u]: subtree size;
    // son[u]: heavy child (0 when u has no children); id[u]: position of u in DFS order;
    // a[u]: initial value of node u; val[i]: initial value at DFS position i.
    vector<int> top, dep, parent, siz, son, id, a, val;
    int idx, mod;           // idx: last assigned DFS position; mod: the modulus P
    vector<node> tr;        // segment tree, root at index 1, 4n + 1 slots

    // n: number of nodes (numbered 1..n); P: modulus applied to every sum.
    HLD(int n, int P){
        mod = P;
        e.resize(n + 1);
        top.resize(n + 1);
        dep.resize(n + 1);
        parent.resize(n + 1);
        siz.resize(n + 1);
        son.resize(n + 1);
        id.resize(n + 1);
        idx = 0;
        a.resize(n + 1);
        val.resize(n + 1);
        tr.resize((n << 2) + 1);
    }

    // Reduces x (which may be negative) into [0, mod). Needed because C++ `%`
    // keeps the sign of the dividend.
    LL norm(LL x) const {
        x %= mod;
        return x < 0 ? x + mod : x;
    }

    // Adds the undirected edge u - v.
    void add(int u, int v){
        e[u].push_back(v);
        e[v].push_back(u);
    }

    // First pass, starting at u: fills dep, siz and son for u's subtree.
    // parent[u] must already be set (it is 0 for the root). Recursive: the recursion
    // depth equals the height of the tree.
    void dfs1(int u){
        siz[u] = 1;
        dep[u] = dep[parent[u]] + 1;
        for (auto v : e[u]){
            if (v == parent[u]) continue;
            parent[v] = u;
            dfs1(v);
            siz[u] += siz[v];
            // Heavy child = child with the largest subtree (first one wins on ties).
            if (siz[v] > siz[son[u]]) son[u] = v;
        }
    }

    // Second pass: assigns DFS positions, visiting the heavy child first so that every
    // heavy chain and every subtree occupies a contiguous range of positions.
    // up is the head of the chain that u belongs to.
    void dfs2(int u, int up){
        id[u] = ++ idx;
        top[u] = up;
        val[idx] = a[u];
        if (son[u]) dfs2(son[u], up);
        for (auto v : e[u]){
            if (v == parent[u] || v == son[u]) continue;
            dfs2(v, v);     // each light child starts a new chain
        }
    }

    // Recomputes node u's sum from its two children.
    void pushup(int u){
        tr[u].sum = (tr[u << 1].sum + tr[u << 1 | 1].sum) % mod;
    }

    // Adds k (already in [0, mod)) to every position covered by node u and records it
    // as a pending tag for u's children.
    void applyAdd(int u, LL k){
        tr[u].sum = (tr[u].sum + k * (tr[u].r - tr[u].l + 1) % mod) % mod;
        tr[u].add = (tr[u].add + k) % mod;
    }

    // Pushes node u's pending add down to its two children.
    void pushdown(int u){
        if (tr[u].add == 0) return;
        applyAdd(u << 1, tr[u].add);
        applyAdd(u << 1 | 1, tr[u].add);
        tr[u].add = 0;
    }

    // Builds the segment tree for node u over DFS positions [l, r] from val.
    void build(int u, int l, int r){
        if (l == r){
            tr[u] = {l, r, norm(val[r]), 0};
            return;
        }
        tr[u] = {l, r, 0, 0};
        int mid = (l + r) >> 1;
        build(u << 1, l, mid);
        build(u << 1 | 1, mid + 1, r);
        pushup(u);
    }

    // Adds k to every DFS position in [l, r], descending from segment-tree node u.
    void modify(int u, int l, int r, LL k){
        k = norm(k);    // keep sums non-negative even for negative increments
        if (tr[u].l >= l && tr[u].r <= r){
            applyAdd(u, k);
            return;
        }
        pushdown(u);
        int mid = (tr[u].l + tr[u].r) >> 1;
        if (l <= mid) modify(u << 1, l, r, k);
        if (r > mid) modify(u << 1 | 1, l, r, k);
        pushup(u);
    }

    // Returns the sum (mod `mod`) of DFS positions [l, r], descending from node u.
    LL query(int u, int l, int r){
        if (tr[u].l >= l && tr[u].r <= r) return tr[u].sum;
        pushdown(u);
        int mid = (tr[u].l + tr[u].r) >> 1;
        LL sum = 0;
        if (l <= mid) sum = query(u << 1, l, r);
        if (r > mid) sum = (sum + query(u << 1 | 1, l, r)) % mod;
        return sum;
    }

    // Adds k to every node on the tree path u - v (both ends included).
    void modifyRange(int u, int v, int k){
        while(top[u] != top[v]){
            // Lift the endpoint whose chain head is deeper.
            if (dep[top[u]] < dep[top[v]]) swap(u, v);
            modify(1, id[top[u]], id[u], k);
            u = parent[top[u]];
        }
        // Same chain now: positions between the two endpoints are contiguous.
        if (dep[u] > dep[v]) swap(u, v);
        modify(1, id[u], id[v], k);
    }

    // Returns the sum (mod `mod`) of the nodes on the tree path u - v (both ends included).
    LL queryRange(LL u, LL v){
        LL ans = 0;
        while(top[u] != top[v]){
            if (dep[top[u]] < dep[top[v]]) swap(u, v);
            ans = (ans + query(1, id[top[u]], id[u])) % mod;
            u = parent[top[u]];
        }
        if (dep[u] > dep[v]) swap(u, v);
        return (ans + query(1, id[u], id[v])) % mod;
    }

    // Adds k to every node in the subtree of u (positions id[u] .. id[u] + siz[u] - 1).
    void modifySon(LL u, LL k){
        modify(1, id[u], id[u] + siz[u] - 1, k);
    }

    // Returns the sum (mod `mod`) of the nodes in the subtree of u.
    LL querySon(LL u){
        return query(1, id[u], id[u] + siz[u] - 1) % mod;
    }
};
int main(){
ios::sync_with_stdio(false);cin.tie(0);
int n, m, r, p;
cin >> n >> m >> r >> p;
HLD t(n, p);
for (int i = 1; i <= n; i ++ ){
cin >> t.a[i];
}
for (int i = 0; i < n - 1; i ++ ){
int u, v;
cin >> u >> v;
t.add(u, v);
}
t.dfs1(r);
t.dfs2(r, r);
t.build(1, 1, n);
while (m -- ){
int op, x, y, z;
cin >> op >> x;
if (op == 1){
cin >> y >> z;
t.modifyRange(x, y, z);
}
else if (op == 2){
cin >> y;
cout << t.queryRange(x, y) << "\n";
}
else if (op == 3){
cin >> z;
t.modifySon(x, z);
}
else{
cout << t.querySon(x) << "\n";
}
}
return 0;
}