树链剖分总结

转载自zzq巨佬的树链剖分:

http://blog.csdn.net/Love_mona/article/details/79344296

蒟蒻的垂死挣扎

(以洛谷上树链剖分模板为题来介绍:[洛谷P3384] 【模板】树链剖分)

听说树剖很简单

树剖大概算一种思想吧,通过一种巧妙的方式把一棵树的节点有序地排在一个一维数组里,使得其相关询问可以通过线段树等数据结构来维护。

那么难点就在这了,这个巧妙的方法到底是什么呢?

首先引入几个概念:

  • 重结点:子树结点数目最多的结点;
  • 轻节点:父亲节点中除了重结点以外的结点;
  • 重边:父亲结点和重结点连成的边;
  • 轻边:父亲节点和轻节点连成的边;
  • 重链:由多条重边连接而成的路径;
  • 轻链:由多条轻边连接而成的路径;

其中,我们在剖分的过程中需要应用到的就是重链,重链的存在决定了其优秀的复杂度,依靠重链进行跨越就是关键所在。

 

首先是预处理部分

我们需要记录  一个点的父亲(fa)  深度(dep)  以它为根的子树的大小(siz)  它的重儿子(son)

                          它所在重链的最高点(top)  重新排列后它在树组中的位置(w)  以及重新排列后的数组(b)

第一步,我们用通常的dfs将第一行的四个元素处理,这个自己模拟没问题,实在不会看后面整体代码。

第二步,处理第二行的三个数组,同样是一遍dfs,但是我们优先走重儿子,从而让重链在重新排列后靠在一起,而对于top数组的记录就带一个参数往下传就行了,详见代码。

    void build(int k,int tp)    
    {    
        w[k]=++cnt,top[k]=tp,b[cnt]=a[k];    
        if(son[k]) build(son[k],top[k]);    
        for(RG int i=first[k];i;i=s[i].nxt)    
            if(s[i].en!=son[k]&&s[i].en!=fa[k])    
                build(s[i].en,s[i].en);    
    }    

接下来就是操作部分了

对于这一道模板题,我们需要解决的操作有如下几种,本蒟蒻采用线段树维护(前置任务 线段树模板1),具体方法列在其后。

1.修改树上两点路径上所有点的权值(区间加)2.求树上两点路径上所有点的权值和(区间查询)

这两个操作就涉及到了树剖最难的地方,因为两点之间的路径在线段树中是没有连在一起的,那么要怎样才能不重不漏地访问到相应的节点,并且保证正确的复杂度呢,这时候重链就派上用场了。我们知道,一条重链在重新排列后的数组里是靠在一起的,那么我们每次就让较深的节点往上跳,跳到它的top节点处,这样每次跳跃做一个区间处理,直到这两个点在同一条重链上,再进行最后的一个区间处理即可。(这个可以证明是log的时间复杂度)

我觉得写两个很像的函数太蠢了就强行把它们用一个转换函数缩了一下,黑科技黑科技。

    inline void Tr_Enter(int x,int y,int val,int op)//转换    
    {    
        if(w[x]>w[y]) swap(x,y);    
        if(!op) update(1,w[x],w[y],val);    
        else    ans+=Query(1,w[x],w[y]),ans%=p;    
    }    
        
    inline void Tr_Chan(int x,int y,int w,int op)   //跳跃(这个才是重点)    
    {    
        while(top[x]!=top[y])    
        {    
            if(dep[top[x]]<dep[top[y]]) swap(x,y);    
            Tr_Enter(top[x],x,w,op);    
            x=fa[top[x]];    
        }    
        Tr_Enter(x,y,w,op);    
    }    

3.修改以一个点为根的子树的权值(区间加)4.查询以一个点为根的子树的权值和(区间查询)

这两个就简单多了,因为是dfs构的树,他们理所当然就在一起的啦。所以直接进行愉快的区间操作就好了,具体看后面整体代码。

最终形态

    #include<iostream>    
    #include<cstdio>    
    #include<cstdlib>    
    #include<cmath>    
    #include<cstring>    
    #include<algorithm>    
    #define RG register    
    #define ll long long    
    #define N 100020    
    #define ls (now<<1)    
    #define rs ((now<<1)|1)    
    using namespace std;    
        
    inline ll rread()    
    {    
        ll x=0,o=1;    
        char ch=getchar();    
        while((ch>'9'||ch<'0')&&ch!='-') ch=getchar();    
        if(ch=='-') o=-1,ch=getchar();    
        while(ch>='0'&&ch<='9') x=x*10+ch-'0',ch=getchar();    
        return x*o;    
    }    
        
    int n,m,r,p,cnt,t,ans,b[N],a[N],w[N],fa[N],son[N],first[N],top[N],dep[N],siz[N];    
    struct mona{ int nxt,en;         } s[N<<2];    
    struct tree{ int l,r,w,siz,lazy; }tr[N<<2];    
        
    inline void Insert(int x,int y)    
    {    
        s[++t]=(mona) {first[x],y};    
        first[x]=t;    
    }    
        
    inline void Init()    
    {    
        n=rread(),m=rread(),r=rread(),p=rread();    
        for(RG int i=1;i<=n;i++) a[i]=rread(),siz[i]=1;    
        for(RG int i=1;i<n;i++)    
        {    
            int x=rread(),y=rread();    
            Insert(x,y),Insert(y,x);    
        }    
    }    
        
    void dfs(int fat,int k,int deep)  //预处理fa dep siz son    
    {    
        fa[k]=fat,dep[k]=deep;    
        int num=0;    
        for(RG int i=first[k];i;i=s[i].nxt)    
        {    
            int en=s[i].en;    
            if(fat==en) continue ;    
            dfs(k,en,deep+1);    
            siz[k]+=siz[en];    
            if(siz[en]>num) num=siz[en],son[k]=en;    
        }    
    }    
        
    void build(int k,int tp)    
    {    
        w[k]=++cnt,top[k]=tp,b[cnt]=a[k];    
        if(son[k]) build(son[k],top[k]);    
        for(RG int i=first[k];i;i=s[i].nxt)    
            if(s[i].en!=son[k]&&s[i].en!=fa[k])    
                build(s[i].en,s[i].en);    
    }    
    //-----------------------------------------------以上为DFS预处理部分    
    void Build(int now,int l,int r)    
    {    
        tr[now]=(tree) {l,r},tr[now].siz=r-l+1;    
        if(l==r) {tr[now].w=b[l]; return ;}    
        int mid=(l+r)/2;    
        Build(ls,l,mid),Build(rs,mid+1,r);    
        tr[now].w=(tr[ls].w+tr[rs].w+p)%p;    
    }    
        
    inline void pushdown(int now)    
    {    
        tr[ls].w+=tr[now].lazy*tr[ls].siz,tr[rs].w+=tr[now].lazy*tr[rs].siz;    
        tr[ls].lazy+=tr[now].lazy        ,tr[rs].lazy+=tr[now].lazy;    
        tr[ls].w%=p,tr[ls].lazy%=p,tr[rs].w%=p,tr[rs].lazy%=p;    
        tr[now].lazy=0;    
    }    
        
    void update(int now,int l,int r,int val)    
    {    
        if(tr[now].l>=l&&tr[now].r<=r)    
        {    
            tr[now].w+=val*tr[now].siz;    
            tr[now].lazy+=val;    
            return ;    
        }    
        pushdown(now);    
        int mid=(tr[now].l+tr[now].r)/2;    
        if(l<=mid) update(ls,l,r,val);    
        if(r>mid)  update(rs,l,r,val);    
        tr[now].w=(tr[ls].w+tr[rs].w+p)%p;    
    }    
        
    int Query(int now,int l,int r)    
    {    
        int ans=0;    
        if(tr[now].l>=l&&tr[now].r<=r) return tr[now].w;    
        pushdown(now);    
        int mid=(tr[now].l+tr[now].r)/2;    
        if(l<=mid) ans+=Query(ls,l,r);    
        if(r>mid)  ans+=Query(rs,l,r);    
        tr[now].w=(tr[ls].w+tr[rs].w+p)%p,ans%=p;    
        return ans;    
    }    
    //-----------------------------------------------以上为线段树部分    
    inline void Tr_Enter(int x,int y,int val,int op)    
    {    
        if(w[x]>w[y]) swap(x,y);    
        if(!op) update(1,w[x],w[y],val);    
        else    ans+=Query(1,w[x],w[y]),ans%=p;    
    }    
        
    inline void Tr_Chan(int x,int y,int w,int op)    
    {    
        while(top[x]!=top[y])    
        {    
            if(dep[top[x]]<dep[top[y]]) swap(x,y);    
            Tr_Enter(top[x],x,w,op);    
            x=fa[top[x]];    
        }    
        Tr_Enter(x,y,w,op);    
    }    
        
    inline void Ans()    
    {    
        for(RG int i=1;i<=m;i++)    
        {    
            int op=rread();ans=0;    
            if(op==1) { int x=rread(),y=rread(),z=rread();      Tr_Chan(x,y,z,0); }    
            if(op==2) { int x=rread(),y=rread();                Tr_Chan(x,y,0,1); }    
            if(op==3) { int x=rread(),z=rread();  update(1,w[x],w[x]+siz[x]-1,z); }    
            if(op==4) { int x=rread();           ans=Query(1,w[x],w[x]+siz[x]-1); }    
            if(op==2||op==4) printf("%d\n",ans);    
        }    
    }    
    //---------------------------------------以上为答案处理    
    int main()    
    {    
        Init();    
        dfs(0,r,0);    
        build(r,r);    
        Build(1,1,n);    
        Ans();    
    }    

 

posted @ 2018-04-06 22:34  Eternal风度  阅读(195)  评论(0编辑  收藏  举报
/*自定义地址栏logo*/