模板-树链剖分
因为要做 BZOJ 1036 这道题，所以学习一下树链剖分。
/**************************************************************
Problem: 1036
User: MrMorning
Language: C++
Result: Accepted
Time:2868 ms
Memory:7332 kb
****************************************************************/
#include <bits/stdc++.h>
using namespace std;
const int maxn = 30006;
//=====================
vector<int> G[maxn];
vector<int> son[maxn];
int weigh[maxn];
int size[maxn];
int vis[maxn];
int depth[maxn];
int fa[maxn];
int next[maxn];
int top[maxn];
int id[maxn];
void build_tree(int root);
void build(int k, int l, int r);
void dfs(int root, int);
void change(int x, int y);
int querysum(int k, int x, int y);
int querymx(int k, int x, int y);
int path_sum(int u, int v);
int path_max(int u, int v);
void solve();
int p, x;
int n;
int nz = 0;
int value[maxn];
struct dat {
int l, r, sum, mx;
} seg[100005];
//=================
int main() {
// freopen("tree.in", "r", stdin);
// freopen("tree.out", "w", stdout);
scanf("%d", &n);
for (int i = 1; i < n; i++) {
int u, v;
scanf("%d %d", &u, &v);
G[u].push_back(v);
G[v].push_back(u);
}
for (int i = 1; i <= n; i++)
cin >> weigh[i];
solve();
return 0;
}
void build(int k, int l, int r) {
seg[k].l = l;
seg[k].r = r;
if (l == r)
return;
int mid = (l + r) >> 1;
build(k << 1, l, mid);
build(k << 1 | 1, mid + 1, r);
return;
}
void change(int k, int x, int y) {
int l = seg[k].l, r = seg[k].r, mid = (l + r) >> 1;
if (l == r) {
seg[k].sum = seg[k].mx = y;
return;
}
if (x <= mid)
change(k << 1, x, y);
else
change(k << 1 | 1, x, y);
seg[k].sum = seg[k << 1].sum + seg[k << 1 | 1].sum;
seg[k].mx = max(seg[k << 1].mx, seg[k << 1 | 1].mx);
}
void solve() {
build_tree(1);
dfs(1, 1);
build(1, 1, n);
for (int i = 1; i <= n; i++) {
change(1, id[i], weigh[i]);
}
int t;
cin >> t;
while (t--) {
int p, u, v;
char command[10];
scanf("%s", command);
scanf("%d %d", &u, &v);
if (command[1] == 'M')
printf("%d\n", path_max(u, v));
else if (command[1] == 'S')
printf("%d\n", path_sum(u, v));
else
change(1, id[u], v);
}
}
void build_tree(int root) { // dfs1
vector<int>::iterator it;
size[root] = 1;
vis[root] = 1;
int Max = -1;
int ans = -1;
for (it = G[root].begin(); it != G[root].end(); it++)
if (!vis[(*it)]) {
int &v = *it;
depth[v] = depth[root] + 1;
fa[v] = root;
build_tree(v);
size[root] += size[v];
son[root].push_back(v);
if (size[v] > Max) {
Max = size[v];
ans = v;
}
}
next[root] = ans;
}
void dfs(int root, int chain) { // dfs2
id[root] = ++nz;
top[root] = chain;
if (son[root].empty())
return;
dfs(next[root], chain);
std::vector<int>::iterator it;
for (it = son[root].begin(); it != son[root].end(); it++) {
int &v = *it;
if (v != next[root]) {
dfs(v, v);
}
}
return;
}
int query_max(int num, int ql, int qr) {
int L = seg[num].l;
int R = seg[num].r;
int mid = (L + R) >> 1;
if (L == ql && R == qr)
return seg[num].mx;
else if (qr <= mid)
return query_max(num << 1, ql, qr);
else if (ql > mid)
return query_max(num << 1 | 1, ql, qr);
else
return max(query_max(num << 1, ql, mid),
query_max(num << 1 | 1, mid + 1, qr));
}
int query_sum(int num, int ql, int qr) {
int L = seg[num].l;
int R = seg[num].r;
int mid = (L + R) >> 1;
if (L == ql && R == qr)
return seg[num].sum;
else if (qr <= mid)
return query_sum(num << 1, ql, qr);
else if (ql > mid)
return query_sum(num << 1 | 1, ql, qr);
else
return query_sum(num << 1, ql, mid) + query_sum(num << 1 | 1, mid + 1, qr);
}
int path_max(int x, int y) {
int mx = -0x3f3f3f;
while (top[x] != top[y]) {
if (depth[top[x]] < depth[top[y]])
swap(x, y);
mx = max(mx, query_max(1, id[top[x]], id[x]));
x = fa[top[x]];
}
if (id[x] > id[y])
swap(x, y);
mx = max(mx, query_max(1, id[x], id[y]));
return mx;
}
int path_sum(int u, int v) {
int sum = 0;
while (top[u] != top[v]) {
if (depth[top[u]] < depth[top[v]])
swap(u, v);
sum += query_sum(1, id[top[u]], id[u]);
u = fa[top[u]];
}
if (id[u] > id[v])
swap(u, v);
sum += query_sum(1, id[u], id[v]);
return sum;
}