BZOJ1036[ZJOI2008]树的统计——树链剖分+线段树
题目描述
一棵树上有n个节点,编号分别为1到n,每个节点都有一个权值w。我们将以下面的形式来要求你对这棵树完成
一些操作: I. CHANGE u t : 把结点u的权值改为t II. QMAX u v: 询问从点u到点v的路径上的节点的最大权值 I
II. QSUM u v: 询问从点u到点v的路径上的节点的权值和 注意:从点u到点v的路径上的节点包括u和v本身
输入
输入的第一行为一个整数n,表示节点的个数。接下来n – 1行,每行2个整数a和b,表示节点a和节点b之间有
一条边相连。接下来n行,每行一个整数,第i行的整数wi表示节点i的权值。接下来1行,为一个整数q,表示操作
的总数。接下来q行,每行一个操作,以“CHANGE u t”或者“QMAX u v”或者“QSUM u v”的形式给出。
对于100%的数据,保证1<=n<=30000,0<=q<=200000;中途操作中保证每个节点的权值w在-30000到30000之间。
输出
对于每个“QMAX”或者“QSUM”的操作,每行输出一个整数表示要求输出的结果。
样例输入
4
1 2
2 3
4 1
4 2 1 3
12
QMAX 3 4
QMAX 3 3
QMAX 3 2
QMAX 2 3
QSUM 3 4
QSUM 2 1
CHANGE 1 5
QMAX 3 4
CHANGE 3 6
QMAX 3 4
QMAX 2 4
QSUM 3 4
1 2
2 3
4 1
4 2 1 3
12
QMAX 3 4
QMAX 3 3
QMAX 3 2
QMAX 2 3
QSUM 3 4
QSUM 2 1
CHANGE 1 5
QMAX 3 4
CHANGE 3 6
QMAX 3 4
QMAX 2 4
QSUM 3 4
样例输出
4
1
2
2
10
6
5
6
5
16
1
2
2
10
6
5
6
5
16
单点修改、路径求最大值、路径求和,直接上树链剖分,但要注意求最大值时因为可能有负数,所以最小值要设成-INF。
#include<set> #include<map> #include<queue> #include<stack> #include<cmath> #include<vector> #include<cstdio> #include<cstring> #include<iostream> #include<algorithm> #define ll long long using namespace std; int n,m; int tot; int num; int x,y; char ch[30]; int f[30010]; int d[30010]; int s[30010]; int to[60010]; int mx[240010]; int son[30010]; int top[30010]; int size[30010]; int head[30010]; int next[30010]; int sum[240010]; void add(int x,int y) { tot++; next[tot]=head[x]; head[x]=tot; to[tot]=y; } void dfs(int x,int fa) { size[x]=1; f[x]=fa; d[x]=d[fa]+1; for(int i=head[x];i;i=next[i]) { if(to[i]!=fa) { dfs(to[i],x); size[x]+=size[to[i]]; if(size[to[i]]>size[son[x]]) { son[x]=to[i]; } } } } void dfs2(int x,int tp) { s[x]=++num; top[x]=tp; if(son[x]) { dfs2(son[x],tp); } for(int i=head[x];i;i=next[i]) { if(to[i]!=f[x]&&to[i]!=son[x]) { dfs2(to[i],to[i]); } } } void updata(int rt) { sum[rt]=sum[rt<<1]+sum[rt<<1|1]; mx[rt]=max(mx[rt<<1],mx[rt<<1|1]); } void change(int rt,int l,int r,int k,int v) { if(l==r) { sum[rt]=v; mx[rt]=v; return ; } int mid=(l+r)>>1; if(k<=mid) { change(rt<<1,l,mid,k,v); } else { change(rt<<1|1,mid+1,r,k,v); } updata(rt); } int querysum(int rt,int l,int r,int L,int R) { if(L<=l&&r<=R) { return sum[rt]; } int mid=(l+r)>>1; int res=0; if(L<=mid) { res+=querysum(rt<<1,l,mid,L,R); } if(R>mid) { res+=querysum(rt<<1|1,mid+1,r,L,R); } return res; } int querymax(int rt,int l,int r,int L,int R) { if(L<=l&&r<=R) { return mx[rt]; } int mid=(l+r)>>1; if(R<=mid) { return querymax(rt<<1,l,mid,L,R); } else if(L>mid) { return querymax(rt<<1|1,mid+1,r,L,R); } return max(querymax(rt<<1,l,mid,L,R),querymax(rt<<1|1,mid+1,r,L,R)); } int asksum(int x,int y) { int res=0; while(top[x]!=top[y]) { if(d[top[x]]<d[top[y]]) { swap(x,y); } res+=querysum(1,1,n,s[top[x]],s[x]); x=f[top[x]]; } if(d[x]>d[y]) { swap(x,y); } res+=querysum(1,1,n,s[x],s[y]); return res; } int askmax(int x,int y) { int res=-2147483647; while(top[x]!=top[y]) { if(d[top[x]]<d[top[y]]) { swap(x,y); } res=max(res,querymax(1,1,n,s[top[x]],s[x])); x=f[top[x]]; } if(d[x]>d[y]) { swap(x,y); } res=max(res,querymax(1,1,n,s[x],s[y])); return res; } int main() { scanf("%d",&n); for(int i=1;i<n;i++) { scanf("%d%d",&x,&y); add(x,y); add(y,x); } dfs(1,1); dfs2(1,1); for(int i=1;i<=n;i++) { scanf("%d",&x); change(1,1,n,s[i],x); } scanf("%d",&m); for(int i=1;i<=m;i++) { scanf("%s",ch); scanf("%d%d",&x,&y); if(ch[1]=='H') { change(1,1,n,s[x],y); } else if(ch[1]=='M') { printf("%d\n",askmax(x,y)); } else { printf("%d\n",asksum(x,y)); } } }