BZOJ 1036 - 树链剖分 模板题
今天终于入手了期盼已久的Macbook Pro,十分高兴。。
本题是树链剖分的模板题,下面对这种神奇的方法作个理解。。
首先是动机:如果我们要维护树上的信息(点权或边权)并在线回答两点间的路径和或最大权,我们可以采用树链剖分的方法(如果只需要回答路径和,也可以通过欧拉序列+RMQ+LCA的方式实现)。——什么是树链剖分呢?在一维序列中,我们维护这些值常常采用线段树,那么在树上我们可否也加以拓展呢?
(干货开始。。
答案是肯定的!但是在这之前,我们需要解决两个问题:第一个是如何把树上的一段连续区间转化为可查询的序列中的连续区间;第二个是如何保证查询次数不会太多。这两个问题似乎是相互牵制的,而轻重树链剖分则很好地平衡了它们。它是这么做的:对于一个节点x,设它和以它为根的子树中的节点数量为S(x),定义一个节点的“重儿子”为S值最大的那个儿子,定义一个节点到它的重儿子的边为重边,则由重边组成的路径称之为重路径。那么,x到y的路径就可以视作若干条重路径和重路径之间的轻边组成。对于一条重路径中的元素和和最大值,我们可以用就数据结构去维护了。但是在实际代码中,我们并不把是一棵线段树维护一条重路径,而是利用DFS序,把所有重路径装到一棵线段树中。具体是这样的:我们从根节点DFS,按照DFS序赋予每个点x一个编号pos[x]。这样,每一条重链的pos[x]就是连续的——因为我们先对重儿子DFS,而后再对剩下的儿子重开重链(具体见代码)。修改的时候,由于本题只牵涉到单点修改,因此我们在数据结构上直接修改即可。查询时,如果两个点在同一条重链上(必要条件是这两个点有祖孙关系),直接在数据结构上查询即可;否则,我们需要先求出它们的LCA。设我们查询max(x to y),LCA(x, y)=a,那么我们需要分别对max(x to a)和(y to a)求解。如果x和a仍不在一条重链上,我们就需要把x“跳”到本条链的顶端(代码中的bel[x]),然后再跳一条轻边。在跳的过程中,顺便维护答案。询问sum同理,只是别忘了减去w[a]。到这里,还有一点小疑问:跳一条轻边是否足够?答案是否定的,我们可能需要连续跳多条轻边,但这并不违反规则——对于这些边上的点,显然有bel[x]=x。那么问题又来了:连续跳轻边会不会使时间复杂度退化?答案依然是否定的。画个图即可知道,如果存在一条很长的链,它全部由轻边组成,那么这条链越长(深度越大),这棵树的节点数就会呈平方级增长,超过了题目的数据范围;换句话说,对于题目的数据范围,如果存在这种极端情况,这棵树的深度也是很小的,不用担心。
然而这道题对着黄学长的代码码的,还是打错了一个变量名,查了半个小时,羞耻MAX。。
第一次A了之后遵循了VFK的神谕和bkq的忠告,按照自己的理解重新码了一遍,感觉效果好了很多。。
这种方法对之后的学习仍然适用~
<span style="font-size:12px;">// BZOJ 1036 // Tree Chain Spilt #include <cstdio> #include <cstring> #include <algorithm> using namespace std; const int N=30000+5, M=N*2, INF=0x7f7f7f7f; int n, Q, sz, u, v, x, y; int fa[N][15], w[N], dep[N], size[N], pos[N], bel[N]; #define rep(i,a,b) for (int i=a; i<=b; i++) #define dep(i,a,b) for (int i=a; i>=b; i--) #define read(x) scanf("%d", &x) #define fill(a,x) memset(a, x, sizeof(a)) struct Node { int l, r, maxv, sumv; } node[N*4]; struct Graph { int s, from[M], to[M], pre[M], last[N]; void init() { s=-1; fill(last, -1); } void ine(int a, int b) { s++; from[s]=a, to[s]=b, pre[s]=last[a]; last[a]=s; } void ine2(int a, int b) { ine(a, b); ine(b, a); } } G; #define reg(i,G,u) for (int i=G.last[u]; i!=-1; i=G.pre[i]) void change(int x, int dad) { // 无根树转有根树 + 倍增预处理祖先 size[x]=1; rep(i,1,14) { if (dep[x]<(1<<i)) break; fa[x][i]=fa[fa[x][i-1]][i-1]; } reg(i,G,x) { int y=G.to[i]; if (y==dad) continue; dep[y]=dep[x]+1; fa[y][0]=x; change(y, x); size[x]+=size[y]; } } void build(int o, int L, int R) { // 初始化线段树 node[o].l=L; node[o].r=R; if (L==R) return; int M=(L+R)>>1; build(o<<1, L, M); build(o<<1|1, M+1, R); } void update(int o, int x, int y) { // 线段树单点修改 int L=node[o].l, R=node[o].r, M=(L+R)>>1, ls=o<<1, rs=o<<1|1; if (L==R) { node[o].sumv=node[o].maxv=y; return; } if (x<=M) update(ls, x, y); else update(rs, x, y); node[o].sumv=node[ls].sumv+node[rs].sumv; node[o].maxv=max(node[ls].maxv, node[rs].maxv); } int query_max(int o, int x, int y) { int L=node[o].l, R=node[o].r, M=(L+R)>>1, ls=o<<1, rs=o<<1|1; if (x<=L && R<=y) return node[o].maxv; if (y<=M) return query_max(ls, x, y); else if (x>M) return query_max(rs, x, y); else return max(query_max(ls, x, M), query_max(rs, M+1, y)); } int query_sum(int o, int x, int y) { int L=node[o].l, R=node[o].r, M=(L+R)>>1, ls=o<<1, rs=o<<1|1; if (x<=L && R<=y) return node[o].sumv; if (y<=M) return query_sum(ls, x, y); else if (x>M) return query_sum(rs, x, y); else return query_sum(ls, x, M)+query_sum(rs, M+1, y); } void init(int x, int chain) { int k=0; sz++; pos[x]=sz; // 分配x节点在线段树中的编号。树链剖分中,保证了同一条链的编号是连续的 bel[x]=chain; // 属于链chain reg(i,G,x) { int y=G.to[i]; if (dep[y]>dep[x] && size[y]>size[k]) k=y; // 选择最重的儿子继承重链 } if (k==0) return; // x的为叶节点 init(k, chain); // 继承重链 reg(i,G,x) { int y=G.to[i]; if (dep[y]>dep[x] && k!=y) init(y, y); // 其余儿子新开重链,编号为y } } int solve_max(int x, int a) { int ret=-INF; while (bel[x]!=bel[a]) { // 不在一条重链上就将x跳到链首,走一条轻边,如此反复 ret=max(ret, query_max(1, pos[bel[x]], pos[x])); // pos[bel[x]]即为x节点所在链的顶端节点 x=fa[bel[x]][0]; } ret=max(ret, query_max(1, pos[a], pos[x])); return ret; } int solve_sum(int x, int a) { // 当然solve_sum也可以用欧拉序列实现,但是我比较懒…… int ret=0; while (bel[x]!=bel[a]) { ret+=query_sum(1, pos[bel[x]], pos[x]); x=fa[bel[x]][0]; } ret+=query_sum(1, pos[a], pos[x]); return ret; } int LCA(int x, int y) { if (dep[x]>dep[y]) swap(x, y); int delta=dep[y]-dep[x]; rep(i,0,14) if (delta&(1<<i)) y=fa[y][i]; dep(i,14,0) if (fa[x][i]!=fa[y][i]) x=fa[x][i], y=fa[y][i]; if (x==y) return x; else return fa[x][0]; } char ch[6]; void solve() { build(1, 1, n); rep(i,1,n) update(1, pos[i], w[i]); read(Q); rep(i,1,Q) { scanf("%s%d%d", ch, &x, &y); if (ch[0]=='C') { w[x]=y; update(1, pos[x], y); } else { int a=LCA(x, y); if (ch[1]=='M') printf("%d\n", max(solve_max(x, a), solve_max(y, a))); else printf("%d\n", solve_sum(x, a)+solve_sum(y, a)-w[a]); } } } int main() { read(n); G.init(); rep(i,1,n-1) read(u), read(v), G.ine2(u, v); rep(i,1,n) read(w[i]); change(1, 0); init(1, 1); solve(); return 0; }</span>