算法:树链剖分
去年就看过树链剖分的视频了,当时连树状数组,线段树都没学,对树的 dfs 也一知半解,所以基本完全听不懂。昨天又重新看了一般,感觉思路挺简单,应该比线段树简单吧,从用树链剖分求 LCA 来看确实是这样的,但是没有想到的是用线段树维护树链剖分。QAQ 这应该是我打过最长的代码吧!(3K)
树链剖分
只讲一下基本思路,重在感性理解。
这应该是最典型的图了吧。
然后按遍历顺序把它们变成一个线性的数组
1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 |
---|---|---|---|---|---|---|---|
1 | 3 | 6 | 8 | 7 | 2 | 5 | 4 |
然后就变成这样一个数组了,感性理解,把剖开的链按遍历顺序拼在一起
为什么要这么做,应为我们要用高级数据结构来维护,比如线段树
可以实现查询并更改在在树上的任意两点路径的权值和任意子树的权值。
例题:树链剖分 - 重链剖分
AC 代码:
点击查看代码
#include<bits/stdc++.h>
using namespace std;
#define lc u<<1
#define rc u<<1|1
typedef long long LL;
const int N = 100010;
int n,m,root,p;
int w[N];
vector<int >e[N];
int fa[N],dep[N],sz[N],son[N];
int top[N],id[N],nw[N],cnt;
void dfs1(int u){//处理fa,deo,sz,son
dep[u]=dep[fa[u]]+1;
sz[u]=1;
for(auto v:e[u]){
if(v==fa[u]) continue;
fa[v]=u;
dfs1(v);
sz[u]+=sz[v];
if(sz[son[u]<sz[v]]) son[u]=v;
}
}
void dfs2(int u,int t){//处理top,id,nw
top[u]=t;id[u]=++cnt;
nw[cnt]=w[u];
if(!son[u]) return;
dfs2(son[u],t);
for(auto v:e[u]){
if(v==fa[u]||v==son[u]) continue;
dfs2(v,v);
}
}
//-----------------------------------------------
struct tree{
int l,r;
LL add,sum;
}tr[N*4];
void pushup(int rt){
tr[rt].sum=tr[rt<<1].sum+tr[rt<<1|1].sum;
}
void build(int rt,int l,int r){
tr[rt]={l,r,0,nw[r]};
if(l==r) return;
int mid=(l+r)>>1;
build(rt<<1,l,mid);
build(rt<<1|1,mid+1,r);
pushup(rt);
}
void pushdown(int u){
if(tr[u].add){
tr[lc].sum+=tr[u].add*(tr[lc].r-tr[lc].l+1);
tr[rc].sum+=tr[u].add*(tr[rc].r-tr[rc].l+1);
tr[lc].add+=tr[u].add;
tr[rc].add+=tr[u].add;
tr[u].add=0;
}
}
void update(int rt,int l,int r,int k){//线段树上更新
if(l<=tr[rt].l&&r>=tr[rt].r){
tr[rt].add+=k;
tr[rt].sum+=k*(tr[rt].r-tr[rt].l+1);
return ;
}
pushdown(rt);//不能完全包含更改区间,先把懒标记下传
int mid=tr[rt].l+tr[rt].r>>1;
if(l<=mid) update(rt<<1,l,r,k);
if(r>mid) update(rt<<1|1,l,r,k);
pushup(rt);
}
void update_path(int x,int y,int z){
while(top[x]!=top[y]){
if(dep[top[x]]>dep[top[y]]) swap(x,y);
update(1,id[top[y]],id[y],z);
y=fa[top[y]];
}
if(dep[x]>dep[y]) swap(x,y);
update(1,id[x],id[y],z);
}
void update_tree(int u,int k){
update(1,id[u],id[u]+sz[u]-1,k);//注意减一
}
LL query(int u,int l,int r){//线段树上查询
if(l<=tr[u].l&&tr[u].r<=r){
return tr[u].sum;
}
pushdown(u);
int mid=tr[u].l+tr[u].r>>1;
LL res=0;
if(l<=mid) res+=query(lc,l,r);
if(r>mid) res+=query(rc,l,r);
return res;
}
LL query_path(int x,int y){
LL res=0;
while(top[x]!=top[y]){
if(dep[top[x]]>dep[top[y]]) swap(x,y);
res+=query(1,id[top[y]],id[y]);
y=fa[top[y]];
}
if(dep[x]>dep[y]) swap(x,y);
res+=query(1,id[x],id[y]);
return res;
}
LL query_tree(int u){
return query(1,id[u],id[u]+sz[u]-1);
}
int main(){
scanf("%d%d%d%d",&n,&m,&root,&p);
for(int i=1;i<=n;i++){
scanf("%d",&w[i]);
}
for(int i=1,x,y;i<n;i++){
scanf("%d%d",&x,&y);
e[x].emplace_back(y);
e[y].emplace_back(x);
}
// cout<<"finish cin"<<"\n";
fa[root]=0;
dfs1(root);
dfs2(root,root);
// cout<<"finish dfs"<<"\n";
build(1,1,n);
// cout<<"finish build"<<"\n";
while(m--){
int op,x,y,z;
scanf("%d%d",&op,&x);
if(op==1){
scanf("%d%d",&y,&z);
update_path(x,y,z);
}else if(op==2){
scanf("%d",&y);
printf("%d\n",query_path(x,y)%p);
}else if(op==3){
scanf("%d",&z);
update_tree(x,z);
}else if(op==4){
printf("%d\n",query_tree(x)%p);
}
}
return 0;
}
先这样,代码越长越容易寄!