[ZJOI2008]树的统计
题目描述
一棵树上有 \(n\) 个节点,编号分别为 1 到 \(n\),每个节点都有一个权值 \(w\)。
我们将以下面的形式来要求你对这棵树完成一些操作:
I. CHANGE u t : 把结点 \(u\) 的权值改为 \(t\)。
II. QMAX u v: 询问从点 \(u\) 到点 \(v\) 的路径上的节点的最大权值。
III. QSUM u v: 询问从点 \(u\) 到点 \(v\) 的路径上的节点的权值和。
注意:从点 \(u\) 到点 \(v\) 的路径上的节点包括\(u\) 和 \(v\) 本身。
输入格式
输入文件的第一行为一个整数 \(n\),表示节点的个数。
接下来 \(n-1\) 行,每行 2 个整数 \(a\) 和 \(b\),表示节点 \(a\) 和节点 \(b\) 之间有一条边相连。
接下来一行 \(n\) 个整数,第 \(i\)个整数 \(w_{i}\) 表示节点 \(i\) 的权值。
接下来 1 行,为一个整数\(q\),表示操作的总数。
接下来 \(q\) 行,每行一个操作,以 CHANGE u v 或者 QMAX u v或者 QSUM u v 的形式给出。
输出格式
对于每个QMAX或者 QSUM的操作,每行输出一个整数表示要求输出的结果。
输入输出样例
输入
4
1 2
2 3
4 1
4 2 1 3
12
QMAX 3 4
QMAX 3 3
QMAX 3 2
QMAX 2 3
QSUM 3 4
QSUM 2 1
CHANGE 1 5
QMAX 3 4
CHANGE 3 6
QMAX 3 4
QMAX 2 4
QSUM 3 4
输出
4
1
2
2
10
6
5
6
5
16
说明/提示
对于 \(100\%\)的数据,保证 \(1\le n \le 3\times 10^4\),\(0\le q\le 2\times 10^5\)。
中途操作中保证每个节点的权值 \(w\) 在 \(-3\times 10^4\) 到\(3\times 10^4\)之间。
\(Sol\)
此题是比较模板的树链剖分了。
第一个操作:
线段树单点修改即可。
第二个操作:
线段树维护一个最大值标记即可。
第三个操作:
线段树维护权值和。
\(Code\)
#include<bits/stdc++.h>
using namespace std;
int head[100001],tot,n,q,w[100001],d[1000001],fa[1000001],size[1000001],sum,top[1000001],pos[10000001];
struct data {
int to,nxt;
} e[1000001];
struct SegmentTree {
int sum,l,r,mx;
#define l(x) tree[x].l
#define r(x) tree[x].r
#define sum(x) tree[x].sum
#define mx(x) tree[x].mx
} tree[4000004];
void add(int x,int y) {
e[++tot].to=y;
e[tot].nxt=head[x];
head[x]=tot;
}
void build(int p,int l,int r) {
l(p)=l,r(p)=r;
if(l==r)
return;
int mid=l+r>>1;
build(p<<1,l,mid);
build(p<<1|1,mid+1,r);
}
void change(int p,int x,int v) {
if(l(p)==r(p)) {
sum(p)=v;
mx(p)=v;
return;
}
int mid=l(p)+r(p)>>1;
if(x<=mid) change(p<<1,x,v);
else change(p<<1|1,x,v);
mx(p)=max(mx(p<<1),mx(p<<1|1));
sum(p)=sum(p<<1)+sum(p<<1|1);
}
int query_sum(int p,int l,int r) {
if(l<=l(p)&&r>=r(p))
return sum(p);
int val=0;
int mid=l(p)+r(p)>>1;
if(l<=mid) val+=query_sum(p<<1,l,r);
if(r>mid) val+=query_sum(p<<1|1,l,r);
return val;
}
int query_mx(int p,int l,int r) {
if(l<=l(p)&&r>=r(p))
return mx(p);
int val=-(1<<30);
int mid=l(p)+r(p)>>1;
if(l<=mid) val=max(val,query_mx(p<<1,l,r));
if(r>mid) val=max(val,query_mx(p<<1|1,l,r));
return val;
}
void dfs1(int x) {
size[x]=1;
for(int i=head[x]; i; i=e[i].nxt) {
int v=e[i].to;
if(v==fa[x]) continue;
fa[v]=x;
d[v]=d[x]+1;
dfs1(v);
size[x]+=size[v];
}
}
void dfs2(int x,int t) { //jiedian top
int node=0;
pos[x]=++sum;
top[x]=t;
for(int i=head[x]; i; i=e[i].nxt)
if(d[e[i].to]>d[x]&&size[e[i].to]>size[node])
node=e[i].to;
if(node==0)
return;
dfs2(node,t);
for(int i=head[x]; i; i=e[i].nxt)
if(d[e[i].to]>d[x]&&node!=e[i].to)
dfs2(e[i].to,e[i].to);
}
int solve_sum(int x,int y) {
int val=0;
while(top[x]!=top[y]) {
if(d[top[x]]<d[top[y]]) swap(x,y);
val+=query_sum(1,pos[top[x]],pos[x]);
x=fa[top[x]];
}
if(pos[x]>pos[y]) swap(x,y);
val+=query_sum(1,pos[x],pos[y]);
return val;
}
int solve_mx(int x,int y) {
int val=-(1<<30);
while(top[x]!=top[y]) {
if(d[top[x]]<d[top[y]]) swap(x,y);
val=max(val,query_mx(1,pos[top[x]],pos[x]));
x=fa[top[x]];
}
if(pos[x]>pos[y]) swap(x,y);
val=max(val,query_mx(1,pos[x],pos[y]));
return val;
}
void init() {
scanf("%d",&n);
for(int i=1,x,y; i<n; i++)
scanf("%d%d",&x,&y),add(x,y),add(y,x);
for(int i=1; i<=n; i++)
scanf("%d",&w[i]);
}
void solve() {
build(1,1,n);
for(int i=1; i<=n; i++)
change(1,pos[i],w[i]);
scanf("%d",&q);
for(int i=1; i<=q; i++) {
char c[10];
int x,y;
scanf("%s%d%d",c,&x,&y);
if(c[0]=='C')
change(1,pos[x],y);
else {
if(c[1]=='M')
printf("%d\n",solve_mx(x,y));
else
printf("%d\n",solve_sum(x,y));
}
}
}
int main() {
init();
dfs1(1);
dfs2(1,1);
solve();
return 0;
}