BZOJ1036 [ZJOI2008]树的统计Count 【树链剖分+线段树维护】
BZOJ1036 [ZJOI2008]树的统计Count
Description
一棵树上有n个节点,编号分别为1到n,每个节点都有一个权值w。我们将以下面的形式来要求你对这棵树完成
一些操作: I. CHANGE u t : 把结点u的权值改为t II. QMAX u v: 询问从点u到点v的路径上的节点的最大权值 I
II. QSUM u v: 询问从点u到点v的路径上的节点的权值和 注意:从点u到点v的路径上的节点包括u和v本身
Input
输入的第一行为一个整数n,表示节点的个数。接下来n – 1行,每行2个整数a和b,表示节点a和节点b之间有
一条边相连。接下来n行,每行一个整数,第i行的整数wi表示节点i的权值。接下来1行,为一个整数q,表示操作
的总数。接下来q行,每行一个操作,以“CHANGE u t”或者“QMAX u v”或者“QSUM u v”的形式给出。
对于100%的数据,保证1<=n<=30000,0<=q<=200000;中途操作中保证每个节点的权值w在-30000到30000之间。
Output
对于每个“QMAX”或者“QSUM”的操作,每行输出一个整数表示要求输出的结果。
Sample Input
1 2
2 3
4 1
4 2 1 3
12
QMAX 3 4
QMAX 3 3
QMAX 3 2
QMAX 2 3
QSUM 3 4
QSUM 2 1
CHANGE 1 5
QMAX 3 4
CHANGE 3 6
QMAX 3 4
QMAX 2 4
QSUM 3 4
Sample Output
1
2
2
10
6
5
6
5
16
题解:
树链剖分入门题经典题模板题。
包含了区间求值,区间求和,单点修改各种操作。
老套路,先 dfs 1遍,弄出每个节点的深度、子节点数、以及父亲节点。
第二遍 dfs 开始 剖树为链 对于每个点如果是在重链上,则继续延续下去,否则如果是轻边,则从头开始拉链。
(注意 dfs 的时候一定要按重链先 dfs 完后,再轮到一条条轻链,保证 dfs序的有序性,这样才能达到把这些所有的链拉成一个序列从而维护的效果)
然后开始建树。接着对于不同的询问,就是线段树的基本操作了。
然后要提一下的就是怎样在不同链上查找:
对于区间 x,y ,一定要先把 x,y 跳到同一条重链上,这样才能通过数据结构快速求值。跳的时候把深度更深的节点跳到此链的父亲节点上,直到 x,y 跳至同一条重链上。
跳重链的时候对于求 max,直接比较 querymx(pos[x],fa[bl[x]]) ,对于求 sum,直接把 querysum(pos[x],fa[bl[x]]) 累加到统计中。
代码:
1 #include<bits/stdc++.h> 2 #define inf 0x7fffffff 3 #define N 30005 4 #define M 60005 5 using namespace std; 6 int n,q,cnt,sz; 7 int v[N],dep[N],size[N],hea[N],fa[N]; 8 int pos[N],bl[N]; 9 struct data{ 10 int to,next; 11 }e[M]; 12 struct seg{ 13 int l,r,mx,sum; 14 }t[100005]; 15 void insert(int u,int v) 16 { 17 e[++cnt].to=v; e[cnt].next=hea[u]; hea[u]=cnt; 18 e[++cnt].to=u; e[cnt].next=hea[v]; hea[v]=cnt; 19 } 20 void init() 21 { 22 scanf("%d",&n); 23 for (int i=1; i<n; i++) 24 { 25 int x,y; 26 scanf("%d%d",&x,&y); 27 insert(x,y); 28 } 29 for (int i=1; i<=n; i++) scanf("%d",&v[i]); 30 } 31 void dfs1(int x) 32 { 33 size[x]=1; 34 for (int i=hea[x]; i; i=e[i].next) 35 { 36 if (e[i].to==fa[x]) continue; 37 dep[e[i].to]=dep[x]+1; 38 fa[e[i].to]=x; 39 dfs1(e[i].to); 40 size[x]+=size[e[i].to]; 41 } 42 } 43 void dfs2(int x,int chain) 44 { 45 int k=0; sz++; 46 pos[x]=sz; 47 bl[x]=chain; 48 for (int i=hea[x]; i; i=e[i].next) 49 if (dep[e[i].to]>dep[x] && size[e[i].to]>size[k]) 50 k=e[i].to; 51 if (k==0) return; 52 dfs2(k,chain); 53 for (int i=hea[x]; i; i=e[i].next) 54 if (dep[e[i].to]>dep[x] && k!=e[i].to) 55 dfs2(e[i].to,e[i].to); 56 } 57 void build(int k,int l,int r) 58 { 59 t[k].l=l; t[k].r=r; 60 if (l==r) return; 61 int mid=(l+r)>>1; 62 build(k<<1,l,mid); build(k<<1|1,mid+1,r); 63 } 64 void change(int k,int x,int y) 65 { 66 int l=t[k].l,r=t[k].r,mid=(l+r)>>1; 67 if (l==r) { 68 t[k].sum=t[k].mx=y; return; 69 } 70 if (x<=mid) change(k<<1,x,y); else change(k<<1|1,x,y); 71 t[k].sum=t[k<<1].sum+t[k<<1|1].sum; 72 t[k].mx=max(t[k<<1].mx,t[k<<1|1].mx); 73 } 74 int querymx(int k,int x,int y) 75 { 76 int l=t[k].l,r=t[k].r,mid=(l+r)>>1; 77 if (l==x && y==r) return t[k].mx; 78 if (y<=mid) return querymx(k<<1,x,y); 79 else if (x>mid) return querymx(k<<1|1,x,y); 80 else return max(querymx(k<<1,x,mid),querymx(k<<1|1,mid+1,y)); 81 } 82 int querysum(int k,int x,int y) 83 { 84 int l=t[k].l,r=t[k].r,mid=(l+r)>>1; 85 if (l==x && y==r) return t[k].sum; 86 if (y<=mid) return querysum(k<<1,x,y); 87 else if (x>mid) return querysum(k<<1|1,x,y); 88 else return querysum(k<<1,x,mid)+querysum(k<<1|1,mid+1,y); 89 } 90 int solvesum(int x,int y) 91 { 92 int sum=0; 93 while (bl[x]!=bl[y]) 94 { 95 if (dep[bl[x]]<dep[bl[y]]) swap(x,y); 96 sum+=querysum(1,pos[bl[x]],pos[x]); 97 x=fa[bl[x]]; 98 } 99 if (pos[x]>pos[y]) swap(x,y); 100 sum+=querysum(1,pos[x],pos[y]); 101 return sum; 102 } 103 int solvemx(int x,int y) 104 { 105 int mx=-inf; 106 while (bl[x]!=bl[y]) 107 { 108 if (dep[bl[x]]<dep[bl[y]]) swap(x,y); 109 mx=max(mx,querymx(1,pos[bl[x]],pos[x])); 110 x=fa[bl[x]]; 111 } 112 if (pos[x]>pos[y]) swap(x,y); 113 mx=max(mx,querymx(1,pos[x],pos[y])); 114 return mx; 115 } 116 void solve() 117 { 118 build(1,1,n); 119 for (int i=1; i<=n; i++) 120 change(1,pos[i],v[i]); 121 scanf("%d",&q); 122 char ch[10]; 123 for (int i=1; i<=q; i++) 124 { 125 int x,y; 126 scanf("%s%d%d",ch,&x,&y); 127 if (ch[0]=='C') { 128 v[x]=y; change(1,pos[x],y); 129 } 130 else if (ch[1]=='M') printf("%d\n",solvemx(x,y)); 131 else printf("%d\n",solvesum(x,y)); 132 } 133 } 134 int main() 135 { 136 init(); 137 dfs1(1); 138 dfs2(1,1); 139 solve(); 140 return 0; 141 }
加油加油加油!!!fighting fighting fighting!!!