树链剖分

意义:

树链剖分 就是对一棵树分成几条链,把树形变为线性,减少处理难度

概念

重儿子:对于每一个非叶子节点,它的儿子中 儿子数量最多的那一个儿子 为该节点的重儿子
轻儿子:对于每一个非叶子节点,它的儿子中 非重儿子 的剩下所有儿子即为轻儿子
叶子节点没有重儿子也没有轻儿子(因为它没有儿子。。)
重边:连接任意两个重儿子的边叫做重边
轻边:剩下的即为轻边
重链:相邻重边连起来的 连接一条重儿子 的链叫重链
对于叶子节点,若其为轻儿子,则有一条以自己为起点的长度为1的链
每一条重链以轻儿子为起点

题目大意:

给定一棵有根树,给定每个点初值。 需要处理的问题:

将树从x到y结点最短路径上所有节点的值都加上z
求树从x到y结点最短路径上所有节点的值之和
将以x为根节点的子树内所有节点值都加上z
求以x为根节点的子树内所有节点值之和

分析:

树链剖分+线段树

树剖部分:

需要数组:

int root,n,m,p;
int dfn[N],dfn2[N],fdfn[N];
int top[N],son[N],fa[N],dep[N],size[N];

1.dfs1:

目标:

①找到fa,重儿子(son)

②处理节点深度,子树大小(size)(dep[root]=1,fa[root]=-1,其实本题不固定)

void dfs1(int x,int f,int d)
{
    dep[x]=d;
    size[x]=1;
    int mx=0;
    for(int i=head[x];i;i=bian[i].nxt)
    {
        int y=bian[i].to;
        if(y==f) continue;//不能回走
        fa[y]=x;
        dfs1(y,x,d+1);
        size[x]+=size[y];
        if(size[y]>mx)
        {
            mx=size[y],son[x]=y;//记录重儿子
        }
    }
}

2.dfs2

目标:

①找到dfn,dfn2(子树结尾dfn)便于之后线段树维护区间。

②处理fdfn,记录dfnx是几号点。便于线段树build

③注意:有重儿子,先走重儿子。

结果:

dfn数组中,一棵完整的子树,其dfn也是连续的一段。每条重链也是连续的一段。这样,用线段树很方便维护树上路径的处理。

void dfs2(int x,int f)
{
    dfn[x]=++tot;
    fdfn[tot]=x;//第tot个dfn是x号
    if(!top[x]) top[x]=x;//top未赋值才能赋值
    if(son[x]) top[son[x]]=top[x],dfs2(son[x],x);//先走重儿子
    for(int i=head[x];i;i=bian[i].nxt)
    {
        int y=bian[i].to;
        if(y==son[x]||y==f) continue;
        dfs2(y,x);
    }
    dfn2[x]=tot;//回溯之前记录下子树结尾dfn
}

此处省去线段树常规操作,详见下面代码。

3.work1

利用树剖lca想法,其中一个点一边向上翻的同时,更新值。最后在同一条链上了之后,相当于已经找到了lca直接更新另一条路径。

void work1(int x,int y,int z)
{
    while(top[x]!=top[y])
    {
        if(dep[top[x]]>dep[top[y]]) swap(x,y);//dep[top]深度深的向上翻
        add(1,1,tot,dfn[top[y]],dfn[y],z);
        y=fa[top[y]];
    }
    if(dep[x]>dep[y]) swap(x,y);
    add(1,1,tot,dfn[x],dfn[y],z);//另一边路径
}

work2同理。

4.work3,work4,利用之前记录过的dfn2,可以直接找到子树区间。直接处理即可。

void work3(int x,int z)
{
    add(1,1,tot,dfn[x],dfn2[x],z);
}
int work4(int x)
{
    int sum=0;
    sum=(sum+query(1,1,tot,dfn[x],dfn2[x]))%p;
    return sum;
}

注意事项:

1.每次dfs注意不要返祖。

2.记得取模!!!任何加减,赋值,求和都要提起注意。

3.区间add标记直接加,sum要+c×(len)必须乘区间!!(线段树不过关。。。)

4.root是原来树的根,线段树的根就是1!!(不要混了)RE无数无数无数

详见代码:

#include<bits/stdc++.h>
using namespace std;
const int N=1e5+10;
int a[N];
int root,n,m,p;
int dfn[N],dfn2[N],fdfn[N];
int top[N],son[N],fa[N],dep[N],size[N];
struct node{
    int nxt,to;
}bian[2*N];
int cnt,tot;
int head[N];
void add(int x,int y)
{
    bian[++cnt].nxt=head[x];
    bian[cnt].to=y;
    head[x]=cnt;
}
void dfs1(int x,int f,int d)
{
    dep[x]=d;
    size[x]=1;
    int mx=0;
    for(int i=head[x];i;i=bian[i].nxt)
    {
        int y=bian[i].to;
        if(y==f) continue;
        fa[y]=x;
        dfs1(y,x,d+1);
        size[x]+=size[y];
        if(size[y]>mx)
        {
            mx=size[y],son[x]=y;
        }
    }
}
void dfs2(int x,int f)
{
    dfn[x]=++tot;
    fdfn[tot]=x;
    if(!top[x]) top[x]=x;
    if(son[x]) top[son[x]]=top[x],dfs2(son[x],x);
    for(int i=head[x];i;i=bian[i].nxt)
    {
        int y=bian[i].to;
        if(y==son[x]||y==f) continue;
        dfs2(y,x);
    }
    dfn2[x]=tot;
}
//-------------------以上树剖 ----------------------------------- 
int mod(int x)
{
    while(x>=p) x-=p;
    while(x<0) x+=p;
    return x;
}
struct tree{
    int sum,add;
    #define s(x) t[x].sum
    #define ad(x) t[x].add 
}t[4*N];
void pushup(int x)
{
    s(x)=mod(s(x<<1)+s(x<<1|1));
}
void build(int x,int l,int r)
{
    if(l==r)
    {
        s(x)=mod(a[fdfn[l]]);ad(x)=0;
        return;
    }
    int mid=(l+r)>>1;
    build(x<<1,l,mid);
    build(x<<1|1,mid+1,r);
    pushup(x);
}
void pushdown(int x,int l,int r)//change sum+=ad*len
{
    int s1=x<<1,s2=x<<1|1;
    int mid=(l+r)>>1;
    ad(s1)=mod(ad(s1)+ad(x));
    s(s1)=mod(s(s1)+ad(x)*(mid-l+1));
    ad(s2)=mod(ad(s2)+ad(x));
    s(s2)=mod(s(s2)+ad(x)*(r-mid));
    ad(x)=0;
}
void add(int x,int l,int r,int L,int R,int c)
{
    if(L<=l&&r<=R)
    {
        s(x)=mod(s(x)+mod(c*(r-l+1)));
        ad(x)=mod(ad(x)+c);
        return;
    }
    pushdown(x,l,r);
    int mid=(l+r)>>1;
    if(L<=mid) add(x<<1,l,mid,L,R,c);
    if(mid<R) add(x<<1|1,mid+1,r,L,R,c);
    pushup(x);
}
int query(int x,int l,int r,int L,int R)
{
    if(L<=l&&r<=R)
    {
        return s(x);
    }
    pushdown(x,l,r);
    int mid=(l+r)>>1;
    int res=0;
    if(L<=mid) res=mod(res+query(x<<1,l,mid,L,R));
    if(mid<R) res=mod(res+query(x<<1|1,mid+1,r,L,R));
    return res;
}
//-------------------以上线段树 ----------------------------------- 
void work1(int x,int y,int z)
{
    while(top[x]!=top[y])
    {
        if(dep[top[x]]>dep[top[y]]) swap(x,y);
        add(1,1,tot,dfn[top[y]],dfn[y],z);
        y=fa[top[y]];
    }
    if(dep[x]>dep[y]) swap(x,y);
    add(1,1,tot,dfn[x],dfn[y],z);
}
int work2(int x,int y)
{
    int sum=0;
    while(top[x]!=top[y])
    {
        if(dep[top[x]]>dep[top[y]]) swap(x,y);
        sum=(sum+query(1,1,tot,dfn[top[y]],dfn[y]))%p;
        y=fa[top[y]];
    }
    if(dep[x]>dep[y]) swap(x,y);
    sum=(sum+query(1,1,tot,dfn[x],dfn[y]))%p;
    return sum;
}
void work3(int x,int z)
{
    add(1,1,tot,dfn[x],dfn2[x],z);
}
int work4(int x)
{
    int sum=0;
    sum=(sum+query(1,1,tot,dfn[x],dfn2[x]))%p;
    return sum;
}
int main()
{
    scanf("%d%d%d%d",&n,&m,&root,&p);
    for(int i=1;i<=n;i++)
     scanf("%d",&a[i]);
    int x,y;
    for(int i=1;i<=n-1;i++)
    {
        scanf("%d%d",&x,&y);
        add(x,y);add(y,x);
    }
    dfs1(root,-1,1);
    dfs2(root,-1);
    fa[root]=-1;

    build(1,1,tot);
    int op,z;
    while(m)
    {
        scanf("%d",&op);
        if(op==1)
        {
            scanf("%d%d%d",&x,&y,&z);
            work1(x,y,z);
        }
        else if(op==2)
        {
            scanf("%d%d",&x,&y);
            printf("%d\n",work2(x,y));
        }
        else if(op==3)
        {
            scanf("%d%d",&x,&z);
            work3(x,z);
        }
        else{
            scanf("%d",&x);
            printf("%d\n",work4(x));
        }
        m--;
    }
    return 0;
}

 

posted @ 2018-05-13 12:00  *Miracle*  阅读(209)  评论(0编辑  收藏  举报