BZOJ 1036 [ZJOI2008]树的统计Count | 树链剖分模板
树链剖分的模板题:在点带权树树上维护路径和,最大值和单点修改
这里给出几个定义
以任意点为根,然后记 size (u ) 为以 u 为根的子树的结点个数,令 v 为 u 所有
儿子中 size 值最大的一个儿子,则 ( u , v ) 为重边, v 称为 u 的重儿子。 u 到其余儿子的边为轻边。
根据定义:任何一个点属于且仅属于一条重链(这里一个点也算是重链)
我们称某条路径为重路径(链),当且仅当它全部由重边组成且端点两边没有重边了。
所以我们可以把这个棵树分成若干个重链,经过证明重链的个数不超过O(logn)
以下给出几个性质可以帮助理解:
性质1:如果 ( u , v ) 为轻边,则 size ( v ) <= size ( u ) / 2
性质2:从根到某一点 V 的路径上的轻边个数不大于 O (log n ) 。
性质3:我们称某条路径为重路径(链),当且仅当它全部由重边组成。那么对于
每个点到根的路径上都不超过 O (log n ) 条轻边和 O (log n ) 条重路径。
如果我们可以用数据结构维护每条重链,就可以在O(nlog^2n)的复杂度内完成询问
接下来给出算法的具体实现步骤
核心:用dfs处理dfs序保证每条重链的点在dfs序列的编号连续,用线段树维护每个重链所在dfs序列
用两次DFS计算7个值
fa[x]:x的父亲
deep[x]:x的深度
sz[x]:以x为根的子树大小
son[x]:x的重儿子
top[x]:x所在的重链的深度最小的节点编号(显然重链是一条深度递增的链)
pos[x]:x在序列中的下标
idx[x]:序列的第x位置对应树中节点编号
这个可以用两次dfs维护出来.
接下来考虑我们怎么把(u,v)的路径拆分成若干个重链:
显然找路径是求LCA的过程
我们始终令deep[top[u]]>deep[top[v]],
当top[v]=top[u] 时,显然他们属于同一个重链,我们直接query就好
当top[u]!=top[v]时,我们让u去他top的fa,这样就到了新的重链
而且这次操作可以在logn内完成这部分路径的查询,如此下去,一定能到他们top相同的时候.就OK啦
1 #include<cstdio> 2 #include<algorithm> 3 #include<cstring> 4 #define N 30010 5 #define INF 100000000 6 using namespace std; 7 int ecnt,head[N],q,val[N],a,b,fa[N],deep[N],son[N],sz[N],top[N],tot,pos[N],indx[N],n; 8 //数组的定义见上 9 char s[N]; 10 int read() 11 { 12 int ret=0,neg=1; 13 char j=getchar(); 14 for (;j>'9' || j<'0';j=getchar()) 15 if (j == '-') neg=-1; 16 for (;j>='0' && j<='9';j=getchar()) 17 ret=ret*10+j-'0'; 18 return ret*neg; 19 } 20 struct adj//边 21 { 22 int nxt,v; 23 }e[2*N]; 24 struct node//线段树的点 25 { 26 int l,r,sum,mx; 27 }t[4*N]; 28 inline void add(int u,int v)//加边 29 { 30 e[++ecnt].v=v; 31 e[ecnt].nxt=head[u]; 32 head[u]=ecnt; 33 e[++ecnt].v=u; 34 e[ecnt].nxt=head[v]; 35 head[v]=ecnt; 36 } 37 void dfs1(int x,int father,int depth)//第一次dfs处理深度,子树大小,重儿子是谁 38 { 39 deep[x]=depth,fa[x]=father,sz[x]=1; 40 for (int i=head[x];i;i=e[i].nxt) 41 { 42 int v=e[i].v; 43 if (v==father) continue; 44 dfs1(v,x,depth+1); 45 sz[x]+=sz[v]; 46 if (!son[x] || sz[v]>sz[son[x]]) son[x]=v; 47 } 48 } 49 void dfs2(int x,int TOP)//第二次dfs处理dfs序和top[i] 50 { 51 top[x]=TOP,pos[x]=++tot,indx[pos[x]]=x; 52 if (son[x]!=0) dfs2(son[x],TOP);//首先搜重儿子保证这个重链上的点的dfs序连续 53 for (int i=head[x];i;i=e[i].nxt) 54 if (e[i].v==son[x] || e[i].v==fa[x]) continue; 55 else dfs2(e[i].v,e[i].v); 56 } 57 void pushup(int p)//emmm 58 { 59 t[p].mx=max(t[p<<1].mx,t[p<<1|1].mx); 60 t[p].sum=t[p<<1].sum+t[p<<1|1].sum; 61 } 62 void build(int p,int l,int r)//线段树 63 { 64 t[p].l=l,t[p].r=r; 65 if (l==r) 66 t[p].sum=t[p].mx=val[indx[l]];//当前的左端点实际上要维护是dfs序对应的节点编号 67 else 68 { 69 int mid=l+r>>1; 70 build(p<<1,l,mid); 71 build(p<<1|1,mid+1,r); 72 pushup(p); 73 } 74 } 75 void modify(int p,int l,int k) 76 { 77 if (t[p].l==l && t[p].r==t[p].l) 78 t[p].sum=k,t[p].mx=k; 79 else 80 { 81 int mid=t[p].l+t[p].r>>1; 82 if (l<=mid) modify(p<<1,l,k); 83 else modify(p<<1|1,l,k); 84 pushup(p); 85 } 86 } 87 int querySum(int p,int l,int r) 88 { 89 if (t[p].l==l && t[p].r==r) 90 return t[p].sum; 91 int mid=t[p].l+t[p].r>>1; 92 if (r<=mid) return querySum(p<<1,l,r); 93 if (l>mid) return querySum(p<<1|1,l,r); 94 return querySum(p<<1,l,mid)+querySum(p<<1|1,mid+1,r); 95 } 96 int queryMax(int p,int l,int r) 97 { 98 if (t[p].l==l && t[p].r==r) 99 return t[p].mx; 100 int mid=t[p].l+t[p].r>>1; 101 if (r<=mid) return queryMax(p<<1,l,r); 102 if (l>mid) return queryMax(p<<1|1,l,r); 103 return max(queryMax(p<<1,l,mid),queryMax(p<<1|1,mid+1,r)); 104 } 105 int pathSum(int u,int v) 106 { 107 int ret=0; 108 while (top[u]!=top[v]) 109 { 110 if (deep[top[u]]<deep[top[v]]) swap(u,v);//保证top[u]深度较大 111 ret+=querySum(1,pos[top[u]],pos[u]); 112 u=fa[top[u]]; 113 } 114 if (deep[u]>deep[v]) swap(u,v);//最后别忘了走pos[u]和pos[v]之间的位置 115 return ret+querySum(1,pos[u],pos[v]); 116 } 117 int pathMax(int u,int v) 118 { 119 int ret=-INF; 120 while (top[u]!=top[v]) 121 { 122 if (deep[top[u]]<deep[top[v]]) swap(u,v); 123 ret=max(ret,queryMax(1,pos[top[u]],pos[u])); 124 u=fa[top[u]]; 125 } 126 if (deep[u]>deep[v]) swap(u,v); 127 return max(ret,queryMax(1,pos[u],pos[v])); 128 } 129 int main() 130 { 131 n=read(); 132 for (int i=1;i<n;i++) 133 add(read(),read()); 134 for (int i=1;i<=n;i++) 135 val[i]=read(); 136 dfs1(1,0,0); 137 dfs2(1,1); 138 build(1,1,n); 139 q=read(); 140 while (q--) 141 { 142 scanf("%s%d%d",s,&a,&b); 143 if (s[0]=='C') 144 modify(1,pos[a],b); 145 else if (s[1]=='M') 146 printf("%d\n",pathMax(a,b)); 147 else printf("%d\n",pathSum(a,b)); 148 } 149 return 0; 150 }