树链剖分学习&BZOJ1036
树链剖分,计算机术语,指一种对树进行划分的算法,它先通过轻重边剖分将树分为多条链,保证每个点属于且只属于一条链,然后再通过数据结构(树状数组、SBT、SPLAY、线段树等)来维护每一条链。
以下是几种概念:
常见的路径剖分的方法是轻重树链剖分(启发式剖分)将树中的边分为:轻边和重边 定义size(X)为以X为根的子树的节点个数。 令V为U的儿子节点中size值最大的节点,那么边(U,V)被称为重边,树中重边之外的边被称为轻边。性质:轻边(U,V),size(V)<=size(U)/2。 从根到某一点的路径上,不超过O(logN)条轻边,不超过O(logN)条重路径。
树链剖分是先处理出树上一些链的关系,然后用线段树、树状数组之类。
树链剖分可以处理出一些重链轻链的关系。
查询时可以通过类似LCA的方式求出链上的信息。(在同一条链上可以直接查询)。
code:
/************************************************************** Problem: 1036 User: yekehe Language: C++ Result: Accepted Time:2460 ms Memory:4964 kb ****************************************************************/ #include <cstdio> #include <cstring> #include <algorithm> using namespace std; int read() { char c;while(c=getchar(),(c<'0'||c>'9')&&c!='-'); int x=0,y=1;c=='-'?y=-1:x=c-'0'; while(c=getchar(),c>='0'&&c<='9')x=x*10+c-'0'; return x*y; } const int MAXN=30005; int N,Q,w[MAXN]; int head[MAXN],nxt[MAXN<<1],To[MAXN<<1],cnt; void add(int x,int y) { To[cnt]=y; nxt[cnt]=head[x]; head[x]=cnt; cnt++; } int f[MAXN],son[MAXN],dep[MAXN],siz[MAXN]; int top[MAXN],id[MAXN]; void dfs(int now,int d,int fa) { f[now]=fa,dep[now]=d,siz[now]=1; for(int i=head[now];i!=-1;i=nxt[i]){ if(To[i]==fa)continue; dfs(To[i],d+1,now); siz[now]+=siz[To[i]]; if(siz[To[i]]>siz[son[now]]) son[now]=To[i]; } return ; }//处理出f,dep,siz(siz可以求对子树修改,即修改x~x+siz[x]-1) int cot=0; void build(int now,int tp) { id[now]=++cot,top[now]=tp; if(son[now])build(son[now],tp); for(int i=head[now];i!=-1;i=nxt[i]){ if(To[i]!=son[now]&&To[i]!=f[now]) build(To[i],To[i]); } }//求id(编号),重链的顶端 struct seg{ int x,s; }t[MAXN*4]; void updata(int now,int l,int r,int x,int c) { if(x<l || x>r)return ; if(l==r){t[now]=(seg){c,c};return ;} int mid=l+r>>1; updata(now<<1,l,mid,x,c); updata(now<<1|1,mid+1,r,x,c); t[now].x=max(t[now<<1].x,t[now<<1|1].x); t[now].s=t[now<<1].s+t[now<<1|1].s; } int Qs(int now,int l,int r,int ql,int qr) { if(ql<=l&&qr>=r)return t[now].s; int mid=l+r>>1,ans=0; if(mid>=ql)ans+=Qs(now<<1,l,mid,ql,qr); if(mid<qr) ans+=Qs(now<<1|1,mid+1,r,ql,qr); return ans; } int Qx(int now,int l,int r,int ql,int qr) { if(ql<=l&&qr>=r)return t[now].x; int mid=l+r>>1,ans=-2e9; if(mid>=ql)ans=max(ans,Qx(now<<1,l,mid,ql,qr)); if(mid<qr) ans=max(ans,Qx(now<<1|1,mid+1,r,ql,qr)); return ans; } int findx(int u,int v) { int ans=-2e9; while(top[u]!=top[v]){ if(dep[top[u]]<dep[top[v]])swap(u,v); ans=max(ans,Qx(1,1,cot,id[top[u]],id[u])); u=f[top[u]]; } if(dep[u]<dep[v])swap(u,v); ans=max(ans,Qx(1,1,cot,id[v],id[u])); return ans; }//类似LCA int finds(int u,int v) { int ans=0; while(top[u]!=top[v]){ if(dep[top[u]]<dep[top[v]])swap(u,v); ans+=Qs(1,1,cot,id[top[u]],id[u]); u=f[top[u]]; } if(dep[u]<dep[v])swap(u,v); ans+=Qs(1,1,cot,id[v],id[u]); return ans; } char C[20]; int main() { memset(head,-1,sizeof head); memset(nxt,-1,sizeof nxt); N=read(); register int i,j; for(i=1;i<N;i++){ int x=read(),y=read(); add(x,y),add(y,x); } for(i=1;i<=N;i++)w[i]=read(); dfs(1,1,0),build(1,1); for(i=1;i<=N;i++) updata(1,1,cot,id[i],w[i]); Q=read(); for(i=1;i<=Q;i++){ scanf("%s",C); int x=read(),y=read(); if(C[0]=='C')updata(1,1,cot,id[x],y); else{ if(C[1]=='M')printf("%d\n",findx(x,y)); else printf("%d\n",finds(x,y)); } } return 0; }