树链剖分 入门
开始总结一下树剖!树剖好像大概两个月前就接触了……当时没怎么听懂,也许是底子太差了。这段时间多练了练树形DP之类的,现在觉得非常好理解。
树链剖分,简单的来说就是把树拆成链分别维护,以降低维护的复杂度。
首先要说一下,树剖本身并不是什么难的东西,其实只要两遍dfs,至于代码为什么特别长,那是因为在里面加上了线段树(至少也是60行左右吧,操作要多会更长)所以给人产生了代码很长很难理解的想法,但事实并非如此。
我们先来说一下树剖的基本思路,首先有几个概念:
重儿子:子树节点数目最多的儿子。
轻儿子:一个点的儿子中除了重儿子之外的儿子。
重边:父亲节点和重儿子连成的边。
轻边:父亲节点和轻儿子练成的边。
重链:重边连成的路径。
轻链:轻边连成的路径。
于是有这么一条性质:从根结点到任意节点的路径所经过的轻重链个数之和不超过logn,n为节点个数总数。
下面我们来说一下怎么实现,其实非常简单,只要两次dfs。
第一次因为什么都不知道,所以直接从根结点向下dfs,记录每个节点的深度,子树大小,这个节点的重儿子是谁(就是儿子里面size最大的那个),这个节点的父亲是谁。
第一次dfs之后,我们进行第二次dfs,从根结点开始,这次的目的不大一样,我们要按照先重后轻的顺序,如果该节点有重儿子就先dfs重儿子,然后再dfs轻儿子。这时我们要记录一条链的顶端是谁,所以我们把一条重链的顶端一直向下传,而对于一条轻链,我们则以这个链的起点为顶端。然后同时统计dfs序。
经过这样一波操作,我们完成了这样一件事,保证一棵子树内所有节点编号连续,保证一条链上所有节点编号连续。这个有什么用呢?我们结合线段树等数据结构就能立即想到,可以直接进行区间修改,从而更改一棵子树内的所有节点!
这样的话后两个操作我们实际上已经解决了,直接调用dfs序,在线段树上区间修改区间求和即可。
再看前两个操作,这个比较麻烦,我们可能朴素的想法是求LCA然后暴力修改,但是这样很慢。
我们想到刚才这样一个性质,所以链上节点的编号都是连续的!而且我们还在上面记录了每条链的顶端,这样的话我们可以不断的令深度比较大的点沿着自己所在的链向上跳,在跳的同时修改从当前点到链顶的信息(还是线段树区间修改!),跳到链顶之后,走到链顶节点的父亲继续向上跳,直到两个点在一条链上,再次进行一次区间修改即可,复杂度O(nlog^2n)。
有人可能会有疑惑,这个思路和LCA很像,为什么LCA就不行呢?其实树链剖分它是机智的使用了分重轻儿子,并且在dfs的时候先重后轻,使得每条链上的节点都是连续的,但是如果求LCA就无法保证是连续的,这样就不能进行区间修改。
所以简单的树链剖分其实就是树dfs+线段树!是不是很容易就入门啦!
我们来看一下这题代码。(PS:Dukelv有一份蜜汁70ptsTLE代码,不知道有没有哪位大神愿意看看:传送门)
#include<iostream> #include<cstdio> #include<cmath> #include<algorithm> #include<queue> #include<cstring> #define rep(i,a,n) for(int i = a;i <= n;i++) #define per(i,n,a) for(int i = n;i >= a;i--) #define enter putchar('\n') #define pr pair<int,int> #define mp make_pair #define fi first #define sc second using namespace std; typedef long long ll; const int M = 100005; const int N = 10000005; int read() { int ans = 0,op = 1; char ch = getchar(); while(ch < '0' || ch > '9') { if(ch == '-') op = -1; ch = getchar(); } while(ch >='0' && ch <= '9') { ans *= 10; ans += ch - '0'; ch = getchar(); } return ans * op; } struct seg { ll v,lazy; }t[M<<2]; struct edge { int next,to; }e[M<<1]; int n,m,r,size[M],top[M],dep[M],hson[M],fa[M],dfn[M],head[M],ecnt,mod,a[M]; int op,x,y,z,rk[M],cnt; void add(int x,int y) { e[++ecnt].to = y; e[ecnt].next = head[x]; head[x] = ecnt; } void dfs1(int x,int f,int depth) { fa[x] = f,size[x] = 1,dep[x] = depth; int maxson = -1; for(int i = head[x];i;i = e[i].next) { if(e[i].to == f) continue; dfs1(e[i].to,x,depth+1); size[x] += size[e[i].to]; if(size[e[i].to] > maxson) maxson = size[e[i].to],hson[x] = e[i].to; } } void dfs2(int x,int t) { dfn[x] = ++cnt,rk[cnt] = a[x]; top[x] = t; if(!hson[x]) return; dfs2(hson[x],t); for(int i = head[x];i;i = e[i].next) { if(e[i].to == fa[x] || e[i].to == hson[x]) continue; dfs2(e[i].to,e[i].to); } } void pushdown(int p,int l,int r) { if(t[p].lazy) { int mid = (l+r) >> 1; t[p<<1].v += t[p].lazy * (mid-l+1),t[p<<1].v %= mod; t[p<<1|1].v += t[p].lazy * (r-mid),t[p<<1|1].v %= mod; t[p<<1].lazy += t[p].lazy,t[p<<1].lazy %= mod; t[p<<1|1].lazy += t[p].lazy,t[p<<1|1].lazy %= mod; t[p].lazy = 0; } } void build(int p,int l,int r) { if(l == r) { t[p].v = (ll)rk[l] % mod; return; } int mid = (l+r) >> 1; build(p<<1,l,mid),build(p<<1|1,mid+1,r); t[p].v = (t[p<<1].v + t[p<<1|1].v) % mod; } void update(int p,int l,int r,int kl,int kr,int val) { if(l == kl && r == kr) { t[p].lazy += (ll)val,t[p].v += (ll)val * (r-l+1); t[p].lazy %= mod,t[p].v %= mod; return; } int mid = (l+r) >> 1; pushdown(p,l,r); if(kr <= mid) update(p<<1,l,mid,kl,kr,val); else if(kl > mid) update(p<<1|1,mid+1,r,kl,kr,val); else update(p<<1,l,mid,kl,mid,val),update(p<<1|1,mid+1,r,mid+1,kr,val); t[p].v = (t[p<<1].v + t[p<<1|1].v) % mod; } ll query(int p,int l,int r,int kl,int kr) { if(l == kl && r == kr) return t[p].v % mod; int mid = (l+r) >> 1; pushdown(p,l,r); if(kr <= mid) return query(p<<1,l,mid,kl,kr); else if(kl > mid) return query(p<<1|1,mid+1,r,kl,kr); else return (query(p<<1,l,mid,kl,mid) + query(p<<1|1,mid+1,r,mid+1,kr)) % mod; } void uprange(int x,int y,int val) { val %= mod; while(top[x] != top[y]) { if(dep[top[x]] < dep[top[y]]) swap(x,y); update(1,1,n,dfn[top[x]],dfn[x],val); x = fa[top[x]]; } if(dep[x] > dep[y]) swap(x,y); update(1,1,n,dfn[x],dfn[y],val); } ll qrange(int x,int y) { ll ans = 0; while(top[x] != top[y]) { if(dep[top[x]] < dep[top[y]]) swap(x,y); ans += query(1,1,n,dfn[top[x]],dfn[x]); x = fa[top[x]]; } if(dep[x] > dep[y]) swap(x,y); ans += query(1,1,n,dfn[x],dfn[y]); return ans % mod; } void upson(int x,int val) { update(1,1,n,dfn[x],dfn[x] + size[x] - 1,val); } ll qson(int x) { return query(1,1,n,dfn[x],dfn[x] + size[x] - 1); } int main() { n = read(),m = read(),r = read(),mod = read(); rep(i,1,n) a[i] = read(); rep(i,1,n-1) x = read(),y = read(),add(x,y),add(y,x); dfs1(r,0,1),dfs2(r,r); build(1,1,n); while(m--) { op = read(); if(op == 1) x = read(),y = read(),z = read(),uprange(x,y,z); else if(op == 2) x = read(),y = read(),printf("%lld\n",qrange(x,y)); else if(op == 3) x = read(),z = read(),upson(x,z); else if(op == 4) x = read(),printf("%lld\n",qson(x)); } return 0; }
下面我们来看两道例题!
1.ZJOI2008 树的统计 传送门
这题显然是板子题,其实比上面的模板都好写。
#include<iostream> #include<cstdio> #include<cmath> #include<algorithm> #include<queue> #include<cstring> #define rep(i,a,n) for(int i = a;i <= n;i++) #define per(i,n,a) for(int i = n;i >= a;i--) #define enter putchar('\n') #define pr pair<int,int> #define mp make_pair #define fi first #define sc second using namespace std; typedef long long ll; const int M = 100005; const int N = 10000005; int read() { int ans = 0,op = 1; char ch = getchar(); while(ch < '0' || ch > '9') { if(ch == '-') op = -1; ch = getchar(); } while(ch >='0' && ch <= '9') { ans *= 10; ans += ch - '0'; ch = getchar(); } return ans * op; } struct seg { int maxn,v; }t[M<<2]; struct edge { int next,to; }e[M<<1]; int n,x,y,dfn[M],rk[M],a[M],size[M],hson[M],dep[M],top[M],fa[M],head[M],ecnt,idx; int q; char s[10]; void add(int x,int y) { e[++ecnt].to = y; e[ecnt].next = head[x]; head[x] = ecnt; } void dfs1(int x,int f,int depth) { size[x] = 1,fa[x] = f,dep[x] = depth; int maxson = -1; for(int i = head[x];i;i = e[i].next) { if(e[i].to == f) continue; dfs1(e[i].to,x,depth+1); size[x] += size[e[i].to]; if(size[e[i].to] > maxson) maxson = size[e[i].to],hson[x] = e[i].to; } } void dfs2(int x,int t) { dfn[x] = ++idx,rk[idx] = a[x],top[x] = t; if(!hson[x]) return; dfs2(hson[x],t); for(int i = head[x];i;i = e[i].next) { if(e[i].to == fa[x] || e[i].to == hson[x]) continue; dfs2(e[i].to,e[i].to); } } void build(int p,int l,int r) { if(l == r) { t[p].v = t[p].maxn = rk[l]; return; } int mid = (l+r) >> 1; build(p<<1,l,mid),build(p<<1|1,mid+1,r); t[p].v = t[p<<1].v + t[p<<1|1].v; t[p].maxn = max(t[p<<1].maxn,t[p<<1|1].maxn); } void modify(int p,int l,int r,int x,int val) { if(l == r) { t[p].v = t[p].maxn = val; return; } int mid = (l+r) >> 1; if(x <= mid) modify(p<<1,l,mid,x,val); else modify(p<<1|1,mid+1,r,x,val); t[p].v = t[p<<1].v + t[p<<1|1].v; t[p].maxn = max(t[p<<1].maxn,t[p<<1|1].maxn); } int querys(int p,int l,int r,int kl,int kr) { if(l == kl && r == kr) return t[p].v; int mid = (l+r) >> 1; if(kr<= mid) return querys(p<<1,l,mid,kl,kr); else if(kl > mid) return querys(p<<1|1,mid+1,r,kl,kr); else return querys(p<<1,l,mid,kl,mid) + querys(p<<1|1,mid+1,r,mid+1,kr); } int querym(int p,int l,int r,int kl,int kr) { if(l == kl && r == kr) return t[p].maxn; int mid = (l+r) >> 1; if(kr <= mid) return querym(p<<1,l,mid,kl,kr); else if(kl > mid) return querym(p<<1|1,mid+1,r,kl,kr); else return max(querym(p<<1,l,mid,kl,mid),querym(p<<1|1,mid+1,r,mid+1,kr)); } int mrange(int x,int y) { int cur = -100000; while(top[x] != top[y]) { if(dep[top[x]] < dep[top[y]]) swap(x,y); cur = max(cur,querym(1,1,n,dfn[top[x]],dfn[x])); x = fa[top[x]]; } if(dep[x] > dep[y]) swap(x,y); cur = max(cur,querym(1,1,n,dfn[x],dfn[y])); return cur; } int qrange(int x,int y) { int ans = 0; while(top[x] != top[y]) { if(dep[top[x]] < dep[top[y]]) swap(x,y); ans += querys(1,1,n,dfn[top[x]],dfn[x]); x = fa[top[x]]; } if(dep[x] > dep[y]) swap(x,y); ans += querys(1,1,n,dfn[x],dfn[y]); return ans; } int main() { n = read(); rep(i,1,n-1) x = read(),y = read(),add(x,y),add(y,x); rep(i,1,n) a[i] = read(); dfs1(1,0,1),dfs2(1,1),build(1,1,n); q = read(); while(q--) { scanf("%s",s); if(s[0] == 'C') x = read(),y = read(),modify(1,1,n,dfn[x],y); else if(s[1] == 'M') x = read(),y = read(),printf("%d\n",mrange(x,y)); else x = read(),y = read(),printf("%d\n",qrange(x,y)); } return 0; }
2.HAOI2015 树上操作 传送门
这题也是板子题……但是他有特别诡异的bug……好像全开longlong能A,最诡异的是我把main前面的int删了就A了??!
我这是练树剖还是练线段树啊qaq
#include<iostream> #include<cstdio> #include<cmath> #include<algorithm> #include<queue> #include<cstring> #define rep(i,a,n) for(int i = a;i <= n;i++) #define per(i,n,a) for(int i = n;i >= a;i--) #define enter putchar('\n') #define pr pair<int,int> #define mp make_pair #define fi first #define sc second using namespace std; typedef long long ll; const int M = 100005; const int N = 10000005; ll read() { ll ans = 0,op = 1; char ch = getchar(); while(ch < '0' || ch > '9') { if(ch == '-') op = -1; ch = getchar(); } while(ch >='0' && ch <= '9') { ans *= 10; ans += ch - '0'; ch = getchar(); } return ans * op; } struct seg { ll lazy,v; }t[M<<2]; struct edge { ll next,to; }e[M<<1]; ll n,m,x,y,dfn[M],rk[M],a[M],size[M],hson[M],dep[M],top[M],fa[M],head[M],ecnt,idx; void add(ll x,ll y) { e[++ecnt].to = y; e[ecnt].next = head[x]; head[x] = ecnt; } void dfs1(ll x,ll f,ll depth) { size[x] = 1,fa[x] = f,dep[x] = depth; ll maxson = -1; for(int i = head[x];i;i = e[i].next) { if(e[i].to == f) continue; dfs1(e[i].to,x,depth+1); size[x] += size[e[i].to]; if(size[e[i].to] > maxson) maxson = size[e[i].to],hson[x] = e[i].to; } } void dfs2(ll x,ll t) { dfn[x] = ++idx,rk[idx] = a[x],top[x] = t; if(!hson[x]) return; dfs2(hson[x],t); for(int i = head[x];i;i = e[i].next) { if(e[i].to == fa[x] || e[i].to == hson[x]) continue; dfs2(e[i].to,e[i].to); } } void build(ll p,ll l,ll r) { if(l == r) { t[p].v = rk[l]; return; } ll mid = (l+r) >> 1; build(p<<1,l,mid),build(p<<1|1,mid+1,r); t[p].v = t[p<<1].v + t[p<<1|1].v; } void pushdown(ll p,ll l,ll r) { ll mid = (l+r) >> 1; t[p<<1].v += (ll)(t[p].lazy * (mid-l+1)); t[p<<1|1].v += (ll)(t[p].lazy * (r-mid)); t[p<<1].lazy += t[p].lazy; t[p<<1|1].lazy += t[p].lazy; t[p].lazy = 0; } void modify(ll p,ll l,ll r,ll x,ll val) { if(l == r) { t[p].v += val; return; } if(t[p].lazy != 0) pushdown(p,l,r); ll mid = (l+r) >> 1; if(x <= mid) modify(p<<1,l,mid,x,val); else modify(p<<1|1,mid+1,r,x,val); t[p].v = t[p<<1].v + t[p<<1|1].v; } void update(ll p,ll l,ll r,ll kl,ll kr,ll val) { if(l == kl && r == kr) { t[p].v += (val * (r-l+1)); t[p].lazy += val; return; } ll mid = (l+r) >> 1; if(t[p].lazy != 0) pushdown(p,l,r); if(kr <= mid) update(p<<1,l,mid,kl,kr,val); else if(kl > mid) update(p<<1|1,mid+1,r,kl,kr,val); else update(p<<1,l,mid,kl,mid,val),update(p<<1|1,mid+1,r,mid+1,kr,val); t[p].v = t[p<<1].v + t[p<<1|1].v; } ll querys(ll p,ll l,ll r,ll kl,ll kr) { if(l == kl && r == kr) return t[p].v; ll mid = (l+r) >> 1; if(t[p].lazy != 0) pushdown(p,l,r); if(kr<= mid) return querys(p<<1,l,mid,kl,kr); else if(kl > mid) return querys(p<<1|1,mid+1,r,kl,kr); else return querys(p<<1,l,mid,kl,mid) + querys(p<<1|1,mid+1,r,mid+1,kr); } ll qrange(ll x) { ll ans = 0; while(top[x] != 1) { ans += querys(1,1,n,dfn[top[x]],dfn[x]); x = fa[top[x]]; } ans += querys(1,1,n,1,dfn[x]); return ans; } int main() { n = read(),m = read(); rep(i,1,n) a[i] = read(); rep(i,1,n-1) x = read(),y = read(),add(x,y),add(y,x); dfs1(1,0,1),dfs2(1,1),build(1,1,n); while(m--) { ll op = read(); if(op == 1) x = read(),y = read(),modify(1,1,n,dfn[x],y); if(op == 2) x = read(),y = read(),update(1,1,n,dfn[x],dfn[x]+size[x]-1,y); if(op == 3) x = read(),printf("%lld\n",qrange(x)); } }