BZOJ1036 - 1036 树链剖分
1036: [ZJOI2008]树的统计Count
题目大意
一棵树上有 n 个节点,编号分别为 1 到 n,每个节点都有一个权值 w。我们将以下面的形式要求你对这棵树完成一些操作:
I. CHANGE u t : 把结点u的权值改为t
II. QMAX u v: 询问从点u到点v的路径上的节点的最大权值
III. QSUM u v: 询问从点u到点v的路径上的节点的权值和
注意:从点u到点v的路径上的节点包括u和v本身
数据范围
$1 \le n \le 30000$,$0 \le q \le 200000$;操作过程中保证每个节点的权值 $w$ 在 $-30000$ 到 $30000$ 之间。
解题思路
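树链剖分模板题:先用两遍 DFS 求出每个结点的深度、父结点、子树大小和重儿子,并把树剖分成若干条重链,同一条重链上的结点编号连续;再按编号建立线段树,维护区间和与区间最大值。查询 u 到 v 的路径时,每次让链顶更深的一端沿重链跳到链顶的父结点,并对跳过的这一段做一次线段树区间查询,直到两点位于同一条重链上;修改结点权值即线段树上的单点修改。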
AC代码
#include<cstdio>
#include<cstring>
#include<cstdlib>
#include<cmath>
#include<algorithm>
#include<vector>
using namespace std;
// 64-bit integer type used for weight sums.
typedef long long LL;

// Sentinel magnitude: maximum computations start from -INF.
const int INF = 100000;

// Capacity used to size the statically allocated arrays.
const int maxn = 300000;
// One segment-tree node. Node x keeps its children at indices 2x and 2x+1.
struct TREE {
    int l;     // left end of the covered position range (inclusive)
    int r;     // right end of the covered position range (inclusive)
    int Max;   // maximum weight inside [l, r]
    int lazy;  // pending "assign this value" tag; 0 means nothing is pending
    LL sum;    // sum of the weights inside [l, r]
} tree[maxn * 4 + 5];
// Per-node bookkeeping for the heavy-light decomposition.
struct SPLIT {
    int size;  // number of nodes in the subtree rooted here
    int dep;   // depth of the node (main() starts the root at depth 1)
    int fa;    // parent node (main() passes -1 for the root)
    int son;   // heavy child, or -1 when the node has no child
    int id;    // index assigned by Dfs2, i.e. the node's segment-tree position
    int top;   // topmost node of the heavy chain containing this node
} split[maxn + 5];
// rid[k]: the node that received decomposition index k (inverse of split[].id).
int rid[maxn + 5];
// w[u]: initial weight of node u as read from the input.
int w[maxn + 5];
// E[u]: adjacency list of node u.
std::vector<int> E[maxn + 5];
// Next decomposition index Dfs2 will hand out.
int cnt;
// Segment tree: node x covers positions [tree[x].l, tree[x].r]; its children are nodes 2x and 2x+1.
void Build(int L, int R, int x) {
tree[x].l = L; tree[x].r = R, tree[x].lazy = 0, tree[x].Max = -INF;
if(L == R) {
tree[x].Max = w[rid[L]];
tree[x].sum = (LL)w[rid[L]];
tree[x].lazy = 0;
return ;
}
int mid = (L + R) / 2;
Build(L, mid, x * 2);
Build(mid + 1, R, x * 2 + 1);
tree[x].sum = tree[x * 2].sum + tree[x * 2 + 1].sum;
tree[x].Max = max(tree[x * 2].Max, tree[x * 2 + 1].Max);
}
void PushDown(int x) {
if(tree[x].lazy) {
tree[x * 2].lazy = tree[x].lazy;
tree[x * 2 + 1].lazy = tree[x].lazy;
tree[x * 2].sum = (LL)(tree[x * 2].r - tree[x * 2].l + 1) * (LL)tree[x].lazy;
tree[x * 2 + 1].sum = (LL)(tree[x * 2 + 1].r - tree[x * 2 + 1].l + 1) * (LL)tree[x].lazy;
tree[x].lazy = 0;
}
}
// Returns the sum of the weights at positions [L, R], searching from node x.
// Callers start at the root (x = 1).
LL QuerySum(int L, int R, int x) {
    const TREE &node = tree[x];
    if (node.l >= L && node.r <= R) return node.sum;  // fully covered

    PushDown(x);  // children must be up to date before descending
    const int mid = (node.l + node.r) / 2;
    const LL leftPart = (L <= mid) ? QuerySum(L, R, x * 2) : 0;
    const LL rightPart = (R > mid) ? QuerySum(L, R, x * 2 + 1) : 0;
    return leftPart + rightPart;
}
int QueryMax(int L, int R, int x) {
if(L <= tree[x].l && tree[x].r <= R)return tree[x].Max;
PushDown(x);
int mid = (tree[x].l + tree[x].r) / 2;
int Max = -INF;
if(L <= mid)Max = max(Max, QueryMax(L, R, x * 2));
if(R > mid)Max = max(Max, QueryMax(L, R, x * 2 + 1));
return Max;
}
void Update(int L, int R, int num, int x) {
if(L <= tree[x].l && tree[x].r <= R) {
tree[x].sum = (LL)(tree[x].r - tree[x].l + 1) * num;
tree[x].Max = num;
tree[x].lazy = num;
return ;
}
PushDown(x);
int mid = (tree[x].l + tree[x].r) / 2;
if(L <= mid)Update(L, R, num, x * 2);
if(R > mid)Update(L, R, num, x * 2 + 1);
tree[x].sum = tree[x * 2].sum + tree[x * 2 + 1].sum;
tree[x].Max = max(tree[x * 2].Max, tree[x * 2 + 1].Max);
}
//处理出dep(深度),fa(父结点),size(子树大小),son(重结点)
void Dfs1(int u, int father, int depth) {
split[u].dep = depth;
split[u].fa = father;
split[u].size = 1;
for(int i = 0; i < E[u].size(); i++) {
int v = E[u][i];
if(v != split[u].fa) {
Dfs1(v, u, depth + 1);
split[u].size += split[v].size;
//如果没有被访问或者该节点子树更大,更新重结点
if(split[u].son == -1 || split[v].size > split[split[u].son].size)split[u].son = v;
}
}
}
//处理处,top(每条重链的顶端结点),id(每个结点剖分的编号,也就是Dfs的执行顺序),rid(该编号对应的结点)
void Dfs2(int u, int sta) {
split[u].top = sta;
split[u].id = cnt;
rid[cnt] = u;
cnt++;
if(split[u].son == -1)return ;//叶子结点
Dfs2(split[u].son, sta);//找出一条连续的重链
for(int i = 0; i < E[u].size(); i++) {
int v = E[u][i];
if(v != split[u].son && v != split[u].fa)Dfs2(v, v);//v不是重结点且不是父结点,重新从v开始找重链
}
}
// Returns the sum of the weights on the path between x and y (both included).
// Works like an LCA climb: while the endpoints are on different heavy chains,
// the endpoint whose chain top is deeper (x on a tie) is lifted to the parent
// of that chain top, adding the chain segment it leaves behind.
LL QueryPathSum(int x, int y) {
    LL total = 0;
    while (split[x].top != split[y].top) {
        const int topX = split[x].top;
        const int topY = split[y].top;
        int &lower = (split[topX].dep >= split[topY].dep) ? x : y;
        const int chainTop = split[lower].top;
        total += QuerySum(split[chainTop].id, split[lower].id, 1);
        lower = split[chainTop].fa;  // cross the light edge above the chain
    }

    // Both endpoints now share a chain: one contiguous index range remains.
    const int lo = std::min(split[x].id, split[y].id);
    const int hi = std::max(split[x].id, split[y].id);
    total += QuerySum(lo, hi, 1);
    return total;
}
// Returns the maximum weight on the path between x and y (both included).
// Uses the same chain-by-chain climb as QueryPathSum, combining segment
// results with max instead of addition.
int QueryPathMax(int x, int y) {
    int best = -INF;
    while (split[x].top != split[y].top) {
        const int topX = split[x].top;
        const int topY = split[y].top;
        int &lower = (split[topX].dep >= split[topY].dep) ? x : y;
        const int chainTop = split[lower].top;
        best = std::max(best, QueryMax(split[chainTop].id, split[lower].id, 1));
        lower = split[chainTop].fa;  // cross the light edge above the chain
    }

    // Both endpoints now share a chain: one contiguous index range remains.
    const int lo = std::min(split[x].id, split[y].id);
    const int hi = std::max(split[x].id, split[y].id);
    best = std::max(best, QueryMax(lo, hi, 1));
    return best;
}
// Assigns the value z to every node on the path between x and y (both
// included). Calling it with x == y changes a single node, which is how
// main() implements CHANGE.
//
// The climb mirrors QueryPathSum / QueryPathMax: while the endpoints are on
// different heavy chains, the endpoint whose chain top is deeper is lifted to
// the parent of that chain top after its chain segment has been updated.
void UpdatePath(int x, int y, int z) {
    // Chains are compared by their top nodes, not by the endpoints' parents.
    int fx = split[x].top, fy = split[y].top;
    while (fx != fy) {
        if (split[fx].dep >= split[fy].dep) {
            Update(split[fx].id, split[x].id, z, 1);
            x = split[fx].fa;
        } else {
            Update(split[fy].id, split[y].id, z, 1);
            // Move to the parent NODE of the chain top; a decomposition
            // index (id) is not a node number.
            y = split[fy].fa;
        }
        fx = split[x].top;
        fy = split[y].top;
    }

    // Same chain now: update the remaining contiguous index range.
    if (split[x].id <= split[y].id) {
        Update(split[x].id, split[y].id, z, 1);
    } else {
        Update(split[y].id, split[x].id, z, 1);
    }
}
int n, q;
char s[10];
int main() {
scanf("%d", &n);
for(int i = 1; i < n; i++) {
int u, v; scanf("%d%d", &u, &v);
E[u].push_back(v);
E[v].push_back(u);
}
for(int i = 1; i <= n; i++)scanf("%d", &w[i]);
for(int i = 1; i <= n; i++)split[i].son = -1;
cnt = 1, Dfs1(1, -1, 1); Dfs2(1, 1);
Build(1, n, 1);
scanf("%d", &q);
for(int i = 1; i <= q; i++) {
int x, y;
scanf("%s%d%d", s, &x, &y);
if(s[1] == 'M')printf("%d\n", QueryPathMax(x, y));
else if(s[1] == 'S')printf("%lld\n", QueryPathSum(x, y));
else UpdatePath(x, x, y);
}
return 0;
}