让我们对这棵树进行肢解吧——树链剖分

树链剖分,顾名思义,它先通过轻重边剖分将树分为多条链,保证每个点属于且只属于一条链,然后再通过数据结构(树状数组、BST、SPLAY、线段树等)来维护每一条链。

这里我用的是线段树来维护,感觉应该算是最简单的,但这还是花了我一段时间去理解。//我觉得树链剖分讲解好的博客(https://www.cnblogs.com/ivanovcraft/p/9019090.html)

 

模板题:https://www.luogu.com.cn/problem/P3384

树链剖分,我觉得较为难的点有两个,一个是如何通过遍历这棵树得到树的重链和轻链,另一个是如何用线段树来维护链。

通过这道例题,我们来探寻其奥秘。

如何通过遍历这棵树得到树的重链和轻链?

首先来第一遍dfs,遍历这颗树,得到一些基本的东西,比如这个节点的父节点是谁 f [ ]以x为根节点的子树内所有节点的总数 size[ ]这个节点的在树里面的深度 d [ ]以及   记录当前结点的子节点   里面拥有最多子节点数   的那个子节点 son[ ]。

如图所示

 我们可见,树上的边有些是加粗的边,有些是没有加粗的边。加粗的边连起来每一个节点,我们叫做重链;反之,我们叫做轻链。

你能看出来是怎么找出来重链轻链的吗?如果当前点有很多个子节点,我们仅需看子节点下面有多少个节点,找出最多的那个,然后与这个子节点相连的边就叫重边,一直找下去,可得到树里面所以的重边,然后形成重链。

比如我们看图上的 1号节点,他子节点有3个,我们发现 4号节点下面的节点数最多,于是 1 和 4 之间的边就叫重边;

4 号节点,他子节点有3个,我们发现 9号节点下面的节点数最多,于是 4 和 9 之间的边就叫重边。

如果出现像 6 号节点这种情况,他有两个子节点,但是子节点下面的节点数都为0,也就是下面的节点数相等,那么我们可随便找一条边作为重边。

 

然后做第二遍dfs,这次我们要把重链上的节点都标记一个共同祖先(深度最低的)top [ ],然后通过优先走重链,再走轻链的方法,给每个节点标记上类似于时间戳的值 id [ ],rk数组表示当前时间戳代表的哪个节点。

 

 top搞出来有什么用呢?怎么那么像并查集那样的? 其实,top搞出来和后面的线段树操作有关,也是难点。

id又有什么用?我们可以联想一下,为什么并查集每次做完之后,都要把节点的father都改为一个共同祖先?原因就是为了加速,我们在查询两个点之间的关系时,如果不在一条重链上,我们可以直接把当前点跳到祖先那里,然后再看两者的关系,这是后面要说到的,id还有另一个妙用。

 

如何用线段树来维护链?

比如例题里面要求我们将树从x到y结点最短路径上所有节点的值都加上z。

分两种情况,

一 在同一条重链上面,

  那就好办,我们再次看上图,你会发现重链上的id值都是连续的,这说明了我们可以用线段树来维护区间值,这个好理解。

二 不在同一条重链上面,那么我们要怎么做呢?

  我们来看id值,刚刚讲到,我们在移动点的时候,可直接把当前点跳到他的共同祖先那里,跳的这个过程不能忽略,要用线段树维护,这时候维护的是一个区间(关系到>=2个点)。

  但是这只适用于当前点在一条重链上面,如果不在重链上怎么办?那么我们只能一步一步的走,走的这个过程不能忽略,要用线段树维护,这时候维护的是一个(只关系到1个点)

最终有两种情况了

     1 我们把点都移到了同一条重链上面,如何判断?看id值两者是否相等。相等说明就在同一条重链上面,那么之后处理如第一种情况

  2 我们把点移到了一条轻链上面。我们只能通过一步一步走,走到一起。

 

可能我们现在还是有点懵逼,我用一个表格来表示(依据上面那个图)

 

 可看到重链基本上涉及两个以上的区间,轻链在修改时只能类似去到一个点上面去修改。

比如我要改8 到 14 节点的值,最终改的是线段树区间里面的(2,5)和(6,6)。在程序里面操作不会直接(2,6)这么修改。

其实就一句话,涉及到轻链上面的改动或查询,一定是一个一个值的改,比如(6,6)、(7,7);而不是直接(6,7);而重链的话,可一个一个值改,也可一段一段改。

 

最后附上模板题代码:

 

#include <bits/stdc++.h>
#define maxn 1000005
using namespace std;
struct node
{
    int lazy,l,r,sum;
};
node a[maxn];
int op,x,y,z,mod,n,m,r,p,i,first[maxn],dis[maxn],next[maxn],value[maxn],zhi[maxn],tot,size[maxn],id[maxn],f[maxn],depth[maxn],son[maxn],top[maxn],cnt,rank[maxn];
void add(int x,int y)
{
    tot++;
    next[tot]=first[x];
    first[x]=tot;
    //value[tot]=v;
    zhi[tot]=y;
}
void dfs1(int x)
{
    int k;
    k=first[x],
    size[x]=1,
    depth[x]=depth[f[x]]+1;
    while (k!=-1)
    {
        if (zhi[k]!=f[x])
        {
            f[zhi[k]]=x,
            dfs1(zhi[k]),
            size[x]+=size[zhi[k]];
            if (size[son[x]]<size[zhi[k]]) son[x]=zhi[k];
        }
        k=next[k];
    }
}
void dfs2(int x,int t)
{
    top[x]=t;
    id[x]=++cnt;
    rank[cnt]=x;
    if (son[x]) dfs2(son[x],t);
    int k=first[x];
    while (k!=-1)
    {
        if (zhi[k]!=son[x] && zhi[k]!=f[x])
            dfs2(zhi[k],zhi[k]);
        k=next[k];
    }
}
void pushup(int num)
{
    a[num].sum=(a[num*2+1].sum+a[num*2].sum)%mod;
}
void pushdown(int num)
{
    if (a[num].lazy)
    {
        a[num*2].lazy=(a[num*2].lazy+a[num].lazy)%mod;
        a[num*2+1].lazy=(a[num*2+1].lazy+a[num].lazy)%mod;
        a[num*2].sum=(a[num*2].sum+(a[num*2].r-a[num*2].l+1)*a[num].lazy)%mod;
        a[num*2+1].sum=(a[num*2+1].sum+(a[num*2+1].r-a[num*2+1].l+1)*a[num].lazy)%mod;
        a[num].lazy=0;
    }
}
void build(int l,int r,int num)
{
    if (l==r)
    {
        a[num].sum=dis[rank[l]];
        a[num].l=a[num].r=l;
        return;
    }
    int mid=(l+r)>>1;
    build (l,mid,num*2),
    build (mid+1,r,num*2+1); 
    a[num].l=a[num*2].l;
    a[num].r=a[num*2+1].r;
    pushup(num);
}
void upgrade_3(int l,int r,int num,int value)
{
    if (l<=a[num].l && a[num].r<=r)
    {
        a[num].lazy=(a[num].lazy+value) % mod;
        a[num].sum=(a[num].sum+(a[num].r-a[num].l+1)*value)% mod;
        return;
    }
    pushdown(num);
    int mid=(a[num].l+a[num].r)/2;
    if (mid>=l) upgrade_3(l,r,num*2,value);
    if (mid<r) upgrade_3(l,r,num*2+1,value);
    pushup(num);
}
void upgrade_1(int x,int y,int value)
{
    while (top[x]!=top[y])
    {
        if (depth[top[x]]<depth[top[y]]) swap(x,y);
        upgrade_3(id[top[x]],id[x],1,value);
        x=f[top[x]];
    }
    if (id[x]>id[y]) swap(x,y);
    upgrade_3(id[x],id[y],1,value);
}
int query(int l,int r,int num)
{
    if (a[num].l>=l && a[num].r<=r) return a[num].sum;
    pushdown(num);
    int mid=(a[num].l+a[num].r) /2,tot=0;
    if (mid>=l) tot+=query(l,r,num*2);
    if (mid<r) tot+=query(l,r,num*2+1);
    return tot%mod;
}
int sum(int x,int y)
{
    int ans=0;
    while (top[x]!=top[y])
    {
        if (depth[top[x]]<depth[top[y]]) swap(x,y);
        ans=(ans+query(id[top[x]],id[x],1))%mod;
        x=f[top[x]];
    }
    if (id[x]>id[y]) swap(x,y);
    return (ans+query(id[x],id[y],1))%mod; 
}
int main()
{
    scanf("%d%d%d%d",&n,&m,&r,&mod);
    memset(first,-1,sizeof(first));
    for (i=1;i<=n;i++) scanf("%d",&dis[i]);
    for (i=1;i<=n-1;i++) 
    {
        scanf("%d%d",&x,&y);
        add(x,y);
        add(y,x);
    }
    cnt=0,dfs1(r),dfs2(r,r);
    build(1,n,1);
    for (i=1;i<=m;i++)
    {
        scanf("%d",&op);
        switch(op)
        {
            case 1:scanf("%d%d%d",&x,&y,&z),upgrade_1(x,y,z);break; 
            case 2:scanf("%d%d",&x,&y),printf("%d\n",sum(x,y));break;
            case 3:scanf("%d%d",&x,&z),upgrade_3(id[x],id[x]+size[x]-1,1,z);break;
            case 4:scanf("%d",&x),printf("%d\n",query(id[x],id[x]+size[x]-1,1));break;
        }
    }
    return 0;
} 
View Code

 

posted @ 2020-02-04 22:01  Y-KnightQin  阅读(160)  评论(0编辑  收藏  举报