BZOJ - 1036 [ZJOI2008]树的统计Count 树链剖分入门
[ZJOI2008]树的统计Count
escription
一棵树上有n个节点,编号分别为1到n,每个节点都有一个权值w。我们将以下面的形式来要求你对这棵树完成
一些操作: I. CHANGE u t : 把结点u的权值改为t II. QMAX u v: 询问从点u到点v的路径上的节点的最大权值 I
II. QSUM u v: 询问从点u到点v的路径上的节点的权值和 注意:从点u到点v的路径上的节点包括u和v本身
Input
输入的第一行为一个整数n,表示节点的个数。接下来n – 1行,每行2个整数a和b,表示节点a和节点b之间有
一条边相连。接下来n行,每行一个整数,第i行的整数wi表示节点i的权值。接下来1行,为一个整数q,表示操作
的总数。接下来q行,每行一个操作,以“CHANGE u t”或者“QMAX u v”或者“QSUM u v”的形式给出。
对于100%的数据,保证1<=n<=30000,0<=q<=200000;中途操作中保证每个节点的权值w在-30000到30000之间。
Output
对于每个“QMAX”或者“QSUM”的操作,每行输出一个整数表示要求输出的结果。
Sample Input
4
1 2
2 3
4 1
4 2 1 3
12
QMAX 3 4
QMAX 3 3
QMAX 3 2
QMAX 2 3
QSUM 3 4
QSUM 2 1
CHANGE 1 5
QMAX 3 4
CHANGE 3 6
QMAX 3 4
QMAX 2 4
QSUM 3 4
1 2
2 3
4 1
4 2 1 3
12
QMAX 3 4
QMAX 3 3
QMAX 3 2
QMAX 2 3
QSUM 3 4
QSUM 2 1
CHANGE 1 5
QMAX 3 4
CHANGE 3 6
QMAX 3 4
QMAX 2 4
QSUM 3 4
Sample Output
4
1
2
2
10
6
5
6
5
16
1
2
2
10
6
5
6
5
16
思路:树链剖分,分成多条链及多个区间,线段树维护区间值, 板子题~~~
#include<bits/stdc++.h> #define ll long long #define lson l,m,rt<<1 #define rson m+1,r,rt<<1|1 #define pb push_back #define P pair<int,int> #define INF 1e18 using namespace std; const int maxn = 30005; int head[maxn],Next[maxn<<1],To[maxn<<1],fa[maxn],id[maxn],top[maxn],tp,pos[maxn]; int tot,cnt,size[maxn],son[maxn],a[maxn],deep[maxn],n; ll tree[maxn<<2],Max[maxn<<2]; bool fg; void add(int u,int v) { Next[++cnt]=head[u]; head[u]=cnt; To[cnt]=v; } void dfs1(int u,int f,int dep) { fa[u]=f; deep[u]=dep+1; int mx=0; for(int i=head[u]; i!=-1; i=Next[i]) { int v=To[i]; if(v==f) continue; dfs1(v,u,dep+1); size[u]+=size[v]; if(size[v]>mx) mx=size[v],son[u]=v; } size[u]++; } void dfs2(int u,int tp) { id[u]=++tot; pos[tot]=u; top[u]=tp; if(son[u]) dfs2(son[u],tp); for(int i=head[u]; i!=-1; i=Next[i]) { int v=To[i]; if(v!=son[u]&&v!=fa[u]) dfs2(v,v); } } void push_up(int rt) { Max[rt]=max(Max[rt<<1],Max[rt<<1|1]); tree[rt]=tree[rt<<1]+tree[rt<<1|1]; } void build(int l,int r,int rt) { if(l==r) { tree[rt]=a[pos[l]]; Max[rt]=a[pos[l]]; return; } int m=(l+r)>>1; build(lson); build(rson); push_up(rt); } void updata(int l,int r,int rt,int L,int val) { if(l==r) { Max[rt]=val; tree[rt]=val; return ; } int m=(l+r)>>1; if(L<=m) updata(lson,L,val); else updata(rson,L,val); push_up(rt); } ll query(int l,int r,int rt,int L,int R) { if(L<=l&&r<=R) { if(fg) return tree[rt]; else return Max[rt]; } int m=(l+r)>>1; ll sum=0,Maxx=-INF; if(fg) { if(L<=m) sum+=query(lson,L,R); if(R>m) sum+=query(rson,L,R); return sum; } else { if(L<=m) Maxx=max(Maxx,query(lson,L,R)); if(R>m) Maxx=max(Maxx,query(rson,L,R)); return Maxx; } } ll Query(int u,int v) { ll sum=0,Maxx=-INF; while(top[u]!=top[v]) { if(deep[top[u]]<deep[top[v]]) swap(u,v); if(fg) { sum+=query(1,n,1,id[top[u]],id[u]); } else {Maxx=max(Maxx,query(1,n,1,id[top[u]],id[u]));} u=fa[top[u]]; } if(deep[u]>deep[v]) swap(u,v); if(fg) return sum+query(1,n,1,id[u],id[v]); else return max(Maxx,query(1,n,1,id[u],id[v])); } int main() { int u,v,m; scanf("%d",&n); memset(head,-1,sizeof(head)); for(int i=1; i<n; i++) { scanf("%d %d",&u,&v); add(u,v); add(v,u); } for(int i=1; i<=n; i++) { scanf("%d",&a[i]); } dfs1(1,0,0); dfs2(1,1); build(1,n,1); scanf("%d",&m); while(m--) { char s[10]; int u,v; scanf("%s%d%d",s,&u,&v); if(s[0]=='Q') { if(s[1]=='S') fg=1; else fg=0; printf("%lld\n",Query(u,v)); } else { updata(1,n,1,id[u],v); } } return 0; }
PS:摸鱼怪的博客分享,欢迎感谢各路大牛的指点~