树链剖分 学习笔记
前置知识:$dfs$序,线段树
---------------------------------------
我们可以回顾两个问题:
1.树上从$s$到$t$的路径,每个点权值加上$z$。
很简单。遍历整棵树即可。
2.求树上$s$到$t$的权值和。
$LCA$可做。可以利用$LCA$的性质$dis[s]+dis[t]-2*dis[lca]$做即可。时间复杂度$O(n\log n)$。
但是把这两个问题结合起来呢?
每次更改权值后都要重新算一遍$dis$。那么时间复杂度变成$n^2$的了。这时候,我们就需要树链剖分来解决此类问题。
--------------------------------------
树链剖分:把一棵树划分成若干链,转化成若干序列,并用数据结构维护的算法。
概念:
重儿子:父亲节点的所有儿子中子树结点数目最多($size$最大)的结点;
轻儿子:父亲节点中除了重儿子以外的儿子;
重边:父亲结点和重儿子连成的边;
轻边:父亲节点和轻儿子连成的边;
重链:由多条重边连接而成的路径;
轻链:由多条轻边连接而成的路径;
----------------------------------------------------
变量声明:
int size[maxn],son[maxn],dep[maxn],fa[maxn];//大小、重儿子、深度、父亲节点 int top[maxn],dfn[maxn],cnt;//所在重链的顶点、时间戳
由上面的定义可知,重链一定是由轻边连接的。下面我们对树链剖分的复杂度进行证明。
因为每次跳树跳的是轻边,所以每跳一次后,所在树的子树的大小一定至少是原来的二倍。所以单次跳树的复杂度是$O(n\log n)$。总复杂度$O(n\log^2 n)$。
两次$dfs$
第一次$dfs$是进行$dfs$序,把每个结点的子树大小,父亲节点,重儿子和深度处理出来。
inline void dfs_son(int now,int f,int deep) { dep[now]=deep; size[now]=1; fa[now]=f; int maxson=-1; for (int i=head[now];i;i=edge[i].next) { int to=edge[i].to; if (to==f) continue; dfs_son(to,now,deep+1); size[now]+=size[to]; if (size[to]>maxson) maxson=size[to],son[now]=to; } }
第二次$dfs$是记录时间戳,按照优先遍历重儿子的原则,把树处理成若干链。
inline void dfs(int now,int topf) { dfn[now]=++cnt; wt[cnt]=w[now]; top[now]=topf; if (!son[now]) return; dfs(son[now],topf); for (int i=head[now];i;i=edge[i].next) { int to=edge[i].to; if (dfn[to]) continue; dfs(to,to); } }
然后进行线段树建树。代码不贴了。
更新操作:
如果$s$和$t$不在一个链内,那么我们要进行跳树操作。先把当前链的顶端到此结点的路径权值处理了,然后跳到该链顶点的父亲节点。类似于$LCA$的操作。
inline void updrange(int x,int y,int k) { k%=mod; while(top[x]!=top[y]) { if (dep[top[x]]<dep[top[y]]) swap(x,y); update(1,dfn[top[x]],dfn[x],k); x=fa[top[x]]; } if (dep[x]>dep[y]) swap(x,y); update(1,dfn[x],dfn[y],k); }
inline int qrange(int x,int y) { int ans=0; while(top[x]!=top[y]) { if (dep[top[x]]<dep[top[y]]) swap(x,y); ans+=query(1,dfn[top[x]],dfn[x]); ans%=mod; x=fa[top[x]]; } if (dep[x]>dep[y]) swap(x,y); ans+=query(1,dfn[x],dfn[y]); ans%=mod; return ans; }
如果要更新子树,那么就更简单了,只有一行代码:
update(1,dfn[x],dfn[x]+size[x]-1);
至此,树链剖分的所有内容已讲解完毕。
练习题:
都不难,稍微变通一下就是树链剖分的模板题。
模板代码(练习题1):
#include<bits/stdc++.h> #define int long long using namespace std; const int maxn=200005; int n,m,r,mod; int size[maxn],son[maxn],fa[maxn],dep[maxn],w[maxn]; int top[maxn],wt[maxn],dfn[maxn],cnt; int head[maxn*2],jishu; struct node { int next,to; }edge[maxn*2]; int lazy[maxn*4]; struct tre { int l,r,v; }tree[maxn*5]; inline int read() { int x=0,f=1;char ch=getchar(); while(!isdigit(ch)){if (ch=='-') f=-1;ch=getchar();} while(isdigit(ch)){x=x*10+ch-'0';ch=getchar();} return x*f; } inline void add(int from,int to) { edge[++jishu].next=head[from]; edge[jishu].to=to; head[from]=jishu; } inline void dfs_son(int now,int f,int deep) { dep[now]=deep; size[now]=1; fa[now]=f; int maxson=-1; for (int i=head[now];i;i=edge[i].next) { int to=edge[i].to; if (to==f) continue; dfs_son(to,now,deep+1); size[now]+=size[to]; if (size[to]>maxson) maxson=size[to],son[now]=to; } } inline void dfs(int now,int topf) { dfn[now]=++cnt; wt[cnt]=w[now]; top[now]=topf; if (!son[now]) return; dfs(son[now],topf); for (int i=head[now];i;i=edge[i].next) { int to=edge[i].to; if (dfn[to]) continue; dfs(to,to); } } inline void build(int index,int l,int r) { tree[index].l=l; tree[index].r=r; if (l==r) { tree[index].v=wt[l]; tree[index].v%=mod; return; } int mid=(l+r)>>1; build(index*2,l,mid); build(index*2+1,mid+1,r); tree[index].v=(tree[index*2].v+tree[index*2+1].v)%mod; } void pushdown(int index) { lazy[index*2]+=lazy[index]; lazy[index*2+1]+=lazy[index]; tree[index*2].v+=lazy[index]*(tree[index*2].r-tree[index*2].l+1); tree[index*2+1].v+=lazy[index]*(tree[index*2+1].r-tree[index*2+1].l+1); tree[index*2].v%=mod; tree[index*2+1].v%=mod; lazy[index]=0; } inline void update(int index,int l,int r,int k) { if (l<=tree[index].l&&tree[index].r<=r) {lazy[index]+=k;tree[index].v+=k*(tree[index].r-tree[index].l+1);} else{ if (lazy[index]) pushdown(index); int mid=(tree[index].l+tree[index].r)>>1; if (l<=mid) update(index*2,l,r,k); if (r>mid) update(index*2+1,l,r,k); tree[index].v=(tree[index*2].v+tree[index*2+1].v)%mod; } } inline int query(int index,int l,int r) { if (l<=tree[index].l&&tree[index].r<=r) return tree[index].v; else{ if (lazy[index]) pushdown(index); int mid=(tree[index].r+tree[index].l)>>1,res=0; if (l<=mid) res+=query(index*2,l,r); if (r>mid) res+=query(index*2+1,l,r); return res; } } inline void updrange(int x,int y,int k) { k%=mod; while(top[x]!=top[y]) { if (dep[top[x]]<dep[top[y]]) swap(x,y); update(1,dfn[top[x]],dfn[x],k); x=fa[top[x]]; } if (dep[x]>dep[y]) swap(x,y); update(1,dfn[x],dfn[y],k); } inline int qrange(int x,int y) { int ans=0; while(top[x]!=top[y]) { if (dep[top[x]]<dep[top[y]]) swap(x,y); ans+=query(1,dfn[top[x]],dfn[x]); ans%=mod; x=fa[top[x]]; } if (dep[x]>dep[y]) swap(x,y); ans+=query(1,dfn[x],dfn[y]); ans%=mod; return ans; } inline void updson(int x,int k) { k%=mod; update(1,dfn[x],dfn[x]+size[x]-1,k); } inline int qson(int x) { return query(1,dfn[x],dfn[x]+size[x]-1); } signed main() { n=read(),m=read(),r=read(),mod=read(); for (int i=1;i<=n;i++) w[i]=read(); for (int i=1;i<n;i++) { int x=read(),y=read(); add(x,y);add(y,x); } dfs_son(r,0,1); dfs(r,r); build(1,1,n); for (int i=1;i<=m;i++) { int flag=read(),x,y,z; if (flag==1) { x=read(),y=read(),z=read(); updrange(x,y,z); } if (flag==2) { x=read(),y=read(); printf("%lld\n",qrange(x,y)%mod); } if (flag==3) { x=read(),y=read(); updson(x,y); } if (flag==4) { x=read(); printf("%lld\n",qson(x)%mod); } } return 0; }