bzoj1036 [ZJOI2008]树的统计Count 树链剖分模板题
[ZJOI2008]树的统计Count
Description
一棵树上有n个节点,编号分别为1到n,每个节点都有一个权值w。我们将以下面的形式来要求你对这棵树完成
一些操作: I. CHANGE u t : 把结点u的权值改为t II. QMAX u v: 询问从点u到点v的路径上的节点的最大权值 I
II. QSUM u v: 询问从点u到点v的路径上的节点的权值和 注意:从点u到点v的路径上的节点包括u和v本身
Input
输入的第一行为一个整数n,表示节点的个数。接下来n – 1行,每行2个整数a和b,表示节点a和节点b之间有
一条边相连。接下来n行,每行一个整数,第i行的整数wi表示节点i的权值。接下来1行,为一个整数q,表示操作
的总数。接下来q行,每行一个操作,以“CHANGE u t”或者“QMAX u v”或者“QSUM u v”的形式给出。
对于100%的数据,保证1<=n<=30000,0<=q<=200000;中途操作中保证每个节点的权值w在-30000到30000之间。
Output
对于每个“QMAX”或者“QSUM”的操作,每行输出一个整数表示要求输出的结果。
Sample Input
4
1 2
2 3
4 1
4 2 1 3
12
QMAX 3 4
QMAX 3 3
QMAX 3 2
QMAX 2 3
QSUM 3 4
QSUM 2 1
CHANGE 1 5
QMAX 3 4
CHANGE 3 6
QMAX 3 4
QMAX 2 4
QSUM 3 4
Sample Output
4
1
2
2
10
6
5
6
5
16
solution:
这非常明显,是一个树链剖分的模板题。
树链剖分就是把一棵树分成若干条链,分为轻链和重链,并把点映射到一个线性数列,然后用线段树维护这个序列。
对于点u,size最大的儿子称为重儿子,其他的儿子叫轻儿子。e(u,v)为重边,u到u的轻儿子的边称为轻边。重链是由重边组成的链。
性质1:u的任意轻儿子v,有size[v]<=size[u]/2
证明:假设size[v]>size[u]/2,设u的重儿子为v’.
size[v’]>=size[v]
size[v’]+size[v]>=2size[v]>size[u]矛盾
性质2:从根到任意点的路径中,轻边与重链的数量都小于等于lg n
证明:
1、假设路径只有轻边,每经过一条轻边,size至少减一半,所以最多lg n条轻边
2、每经过一条重边,size不会变大,所以1在有重边的时候仍然成立
3、连续的重边会连成重链,所以重链数<=轻边数+1
对于每一棵子树,优先搜索重儿子,每一个点的dfs序即为该点在线段树上的位置。
显然,一棵子树中的点是一个连续的区间,一条重链上的点也是一个连续的区间。
先一次dfs1算出每一个点的深度deep,重儿子son,子树大小size,父亲fa
再一次dfs2算出每一个点的w和top,w表示该点在线段树上的标号,top表示该点所在的重链的顶点。
1 void dfs1(int u){ 2 siz[u]=1; 3 son[u]=0; 4 for(int i=head[u];i;i=Next[i]){ 5 int v=vet[i]; 6 if(v!=fa[u]){ 7 fa[v]=u; 8 deep[v]=deep[u]+1; 9 dfs1(v); 10 siz[u]+=siz[v]; 11 if(siz[son[u]]<siz[v]) 12 son[u]=v; 13 } 14 } 15 } 16 void dfs2(int u,int tp){ 17 w[u]=++num; 18 top[u]=tp; 19 if(son[u]) 20 dfs2(son[u],tp); 21 for(int i=head[u];i;i=Next[i]) 22 if(vet[i]!=fa[u]&&vet[i]!=son[u]) 23 dfs2(vet[i],vet[i]); 24 }
然后构建线段树
1 void insert(int u,int l,int r,int i,int val){ 2 if(l==r){ 3 rmax[u]=sum[u]=val; 4 return; 5 } 6 int mid=(l+r)>>1; 7 if(i<=mid) 8 insert(u+u,l,mid,i,val); 9 else 10 insert(u+u+1,mid+1,r,i,val); 11 sum[u]=sum[u+u]+sum[u+u+1]; 12 rmax[u]=max(rmax[u+u],rmax[u+u+1]); 13 }
1 for(int i=1;i<=n+n+n+n+n;i++) 2 rmax[i]=-1000000000; 3 for(int i=1;i<=n;i++){ 4 int x; 5 scanf("%d",&x); 6 insert(1,1,n,w[i],x); 7 }
然后还有查询和修改
单点修改非常简单,直接用线段树的insert
1 void insert(int u,int l,int r,int i,int val){ 2 if(l==r){ 3 rmax[u]=sum[u]=val; 4 return; 5 } 6 int mid=(l+r)>>1; 7 if(i<=mid) 8 insert(u+u,l,mid,i,val); 9 else 10 insert(u+u+1,mid+1,r,i,val); 11 sum[u]=sum[u+u]+sum[u+u+1]; 12 rmax[u]=max(rmax[u+u],rmax[u+u+1]); 13 }
1 scanf("%s",ca); 2 int x,y; 3 scanf("%d%d",&x,&y); 4 if(ca[0]=='C'){ 5 insert(1,1,n,w[x],y); 6 }
查询要分情况考虑
1、top[x]==top[y]
在同一条重链上
直接查询(x,y)区间
2、top[x]!=top[y]
向上跳一步,直到情况1
很不好解释,还是上代码
int query1(int u,int l,int r,int i,int j){ if(l==i&&r==j) return rmax[u]; int mid=(l+r)>>1; if(j<=mid) return query1(u+u,l,mid,i,j); else if(i>mid) return query1(u+u+1,mid+1,r,i,j); else return max(query1(u+u,l,mid,i,mid),query1(u+u+1,mid+1,r,mid+1,j)); } int query2(int u,int l,int r,int i,int j){ if(l==i&&r==j) return sum[u]; int mid=(l+r)>>1; if(j<=mid) return query2(u+u,l,mid,i,j); else if(i>mid) return query2(u+u+1,mid+1,r,i,j); else return query2(u+u,l,mid,i,mid)+query2(u+u+1,mid+1,r,mid+1,j); } while(m--){ scanf("%s",ca); int x,y; scanf("%d%d",&x,&y); if(ca[0]=='C'){ insert(1,1,n,w[x],y); } else if(ca[0]=='Q'&&ca[1]=='M'){ ans=-1000000000; while(top[x]!=top[y]){ int f1=(top[x]==x?fa[x]:top[x]),f2=(top[y]==y?fa[y]:top[y]); if(deep[f1]>=deep[f2]){ if(top[x]==x) ans=max(ans,query1(1,1,n,w[x],w[x])); else ans=max(ans,query1(1,1,n,w[f1]+1,w[x])); x=f1; } else{ if(top[y]==y) ans=max(ans,query1(1,1,n,w[y],w[y])); else ans=max(ans,query1(1,1,n,w[f2]+1,w[y])); y=f2; } } ans=max(ans,query1(1,1,n,min(w[x],w[y]),max(w[x],w[y]))); printf("%d\n",ans); } else{ ans=0; while(top[x]!=top[y]){ int f1=(top[x]==x?fa[x]:top[x]),f2=(top[y]==y?fa[y]:top[y]); if(deep[f1]>=deep[f2]){ if(top[x]==x) ans+=query2(1,1,n,w[x],w[x]); else ans+=query2(1,1,n,w[f1]+1,w[x]); x=f1; } else{ if(top[y]==y) ans+=query2(1,1,n,w[y],w[y]); else ans+=query2(1,1,n,w[f2]+1,w[y]); y=f2; } } ans+=query2(1,1,n,min(w[x],w[y]),max(w[x],w[y])); printf("%d\n",ans); } }
完整的ac代码
1 #include<iostream> 2 #include<cstdio> 3 #include<cmath> 4 #include<algorithm> 5 #include<cstring> 6 #include<cstdlib> 7 using namespace std; 8 int ans,m,n,vet[100000],head[100000],Next[100000],en,fa[100000],deep[100000],siz[100000],son[100000],top[100000],w[100000],num,sum[1000000],rmax[1000000]; 9 char ca[10]; 10 void addedge(int u,int v){ 11 vet[++en]=v; 12 Next[en]=head[u]; 13 head[u]=en; 14 } 15 void dfs1(int u){ 16 siz[u]=1; 17 son[u]=0; 18 for(int i=head[u];i;i=Next[i]){ 19 int v=vet[i]; 20 if(v!=fa[u]){ 21 fa[v]=u; 22 deep[v]=deep[u]+1; 23 dfs1(v); 24 siz[u]+=siz[v]; 25 if(siz[son[u]]<siz[v]) 26 son[u]=v; 27 } 28 } 29 } 30 void dfs2(int u,int tp){ 31 w[u]=++num; 32 top[u]=tp; 33 if(son[u]) 34 dfs2(son[u],tp); 35 for(int i=head[u];i;i=Next[i]) 36 if(vet[i]!=fa[u]&&vet[i]!=son[u]) 37 dfs2(vet[i],vet[i]); 38 } 39 void insert(int u,int l,int r,int i,int val){ 40 if(l==r){ 41 rmax[u]=sum[u]=val; 42 return; 43 } 44 int mid=(l+r)>>1; 45 if(i<=mid) 46 insert(u+u,l,mid,i,val); 47 else 48 insert(u+u+1,mid+1,r,i,val); 49 sum[u]=sum[u+u]+sum[u+u+1]; 50 rmax[u]=max(rmax[u+u],rmax[u+u+1]); 51 } 52 int query1(int u,int l,int r,int i,int j){ 53 if(l==i&&r==j) 54 return rmax[u]; 55 int mid=(l+r)>>1; 56 if(j<=mid) 57 return query1(u+u,l,mid,i,j); 58 else 59 if(i>mid) 60 return query1(u+u+1,mid+1,r,i,j); 61 else 62 return max(query1(u+u,l,mid,i,mid),query1(u+u+1,mid+1,r,mid+1,j)); 63 } 64 int query2(int u,int l,int r,int i,int j){ 65 if(l==i&&r==j) 66 return sum[u]; 67 int mid=(l+r)>>1; 68 if(j<=mid) 69 return query2(u+u,l,mid,i,j); 70 else 71 if(i>mid) 72 return query2(u+u+1,mid+1,r,i,j); 73 else 74 return query2(u+u,l,mid,i,mid)+query2(u+u+1,mid+1,r,mid+1,j); 75 } 76 int main(){ 77 scanf("%d",&n); 78 for(int i=1;i<n;i++){ 79 int x,y; 80 scanf("%d%d",&x,&y); 81 addedge(x,y); 82 addedge(y,x); 83 } 84 deep[1]=1; 85 dfs1(1); 86 dfs2(1,1); 87 for(int i=1;i<=n+n+n+n+n;i++) 88 rmax[i]=-1000000000; 89 for(int i=1;i<=n;i++){ 90 int x; 91 scanf("%d",&x); 92 insert(1,1,n,w[i],x); 93 } 94 scanf("%d",&m); 95 while(m--){ 96 scanf("%s",ca); 97 int x,y; 98 scanf("%d%d",&x,&y); 99 if(ca[0]=='C'){ 100 insert(1,1,n,w[x],y); 101 } 102 else 103 if(ca[0]=='Q'&&ca[1]=='M'){ 104 ans=-1000000000; 105 while(top[x]!=top[y]){ 106 int f1=(top[x]==x?fa[x]:top[x]),f2=(top[y]==y?fa[y]:top[y]); 107 if(deep[f1]>=deep[f2]){ 108 if(top[x]==x) 109 ans=max(ans,query1(1,1,n,w[x],w[x])); 110 else 111 ans=max(ans,query1(1,1,n,w[f1]+1,w[x])); 112 x=f1; 113 } 114 else{ 115 if(top[y]==y) 116 ans=max(ans,query1(1,1,n,w[y],w[y])); 117 else 118 ans=max(ans,query1(1,1,n,w[f2]+1,w[y])); 119 y=f2; 120 } 121 } 122 ans=max(ans,query1(1,1,n,min(w[x],w[y]),max(w[x],w[y]))); 123 printf("%d\n",ans); 124 } 125 else{ 126 ans=0; 127 while(top[x]!=top[y]){ 128 int f1=(top[x]==x?fa[x]:top[x]),f2=(top[y]==y?fa[y]:top[y]); 129 if(deep[f1]>=deep[f2]){ 130 if(top[x]==x) 131 ans+=query2(1,1,n,w[x],w[x]); 132 else 133 ans+=query2(1,1,n,w[f1]+1,w[x]); 134 x=f1; 135 } 136 else{ 137 if(top[y]==y) 138 ans+=query2(1,1,n,w[y],w[y]); 139 else 140 ans+=query2(1,1,n,w[f2]+1,w[y]); 141 y=f2; 142 } 143 } 144 ans+=query2(1,1,n,min(w[x],w[y]),max(w[x],w[y])); 145 printf("%d\n",ans); 146 } 147 } 148 return 0; 149 }