下笔春蚕食叶声。

笔记: 树链剖分

洛谷-树链剖分模板

前置芝士:链式前向星,线段树,dfs序

这里写的是重链剖分。

参考博客指盗图来源: x正义小学生x

树链剖分可以把一棵树“投影“到一个序列上,然后用线段树维护一些东西。

通过重儿子的性质来保证时间复杂度。


我们首先使用两次dfs进行预处理,将树投影到序列上。

对于一个有儿子的节点,我们定义它最大的儿子为重儿子。

图中,3,6,10,5,8就是重儿子。

我们称像1-3-6-10,2-5,4-8这样的为一条重链。

显然会形成很多条重链,每个点属于且只属于一条重链。

我们定义一个数组 \(top_x\) 表示 \(x\) 所在的重链的最浅节点。

在第一次dfs中,我们求出每个点的深度dep,父亲节点fa,子树大小sz,重儿子。

在第二次dfs中,我们求出每个点的dfs序(时间戳就是在新序列里的位置),并且保存新序列,建线段树。注意要保存每个树上的点 对应在序列里的位置 。称为 \(id_x\)


预处理之后,就要对付询问。

询问和修改子树很显然,是询问和修改序列 \([id[x],id[x]+sz[x]-1]\)

链怎么办呢?树链剖分,意思是将链剖开成多个(一条重链 或者 一条重链的一部分)。

设链的两头为 \(x\)\(y\)

  • \(top_x\neq top_y\) 选其中链头深度较大的重链,询问和修改它,不妨设链头深度大的是 \(x\) , 这条重链剖出来以后,\(x=fa_{top_x}\) , 继续循环,直到 \(top_x=top_y\)
  • \(top_x=top_y\) 询问和修改他们所在的重链。

务必注意update的时候两点的位置!

void updrange(int x,int y,int z){
    z=(z%p+p)%p;
    while(top[x]!=top[y]){
        if(dep[top[x]]<dep[top[y]]) swap(x,y);
        update(1,1,n,id[top[x]],id[x],z); x=fa[top[x]];
    }
    if(dep[x]>dep[y]) swap(x,y);
    update(1,1,n,id[x],id[y],z); 
    return;
}

怎么保证时间复杂度是

  • 如果你学过dsu on tree 的话就会知道,每次从u到一个轻儿子,点个数都会减半,

​ 也就是说,根到一个点最多 \(log n\) 条轻边。

  • 从根节点到一个点,最多有 \(log n+1\) 条重链。 因为最多 \(log n\) 条轻边来分割这些重链。

线段树有一只log,剖链会剖成的条数也有一只log,一共两只log。

时间复杂度 \(O(nlog^2n)\)

#include<bits/stdc++.h>
using namespace std;
const int N=1e5+10;
typedef long long LL;
int n,m,r,p;
int e,to[N<<1],nxt[N<<1],hd[N];
int tim,sz[N],dep[N],fa[N],top[N],son[N],id[N];
LL stval[N],fnval[N];
struct pos{
    int l,r;
    LL sum,lazy;
}t[N<<2];
void add(int a,int b){
    to[++e]=b; nxt[e]=hd[a]; hd[a]=e;
}
void pushup(int rt){
    t[rt].sum=(t[rt<<1].sum+t[rt<<1|1].sum)%p;
}
void pushdown(int rt){
    if(t[rt].lazy){
        t[rt<<1].lazy=(t[rt<<1].lazy+t[rt].lazy)%p;
        t[rt<<1].sum=(t[rt<<1].sum+1ll*(t[rt<<1].r-t[rt<<1].l+1)*t[rt].lazy%p)%p;
        t[rt<<1|1].lazy=(t[rt<<1|1].lazy+t[rt].lazy)%p;
        t[rt<<1|1].sum=(t[rt<<1|1].sum+1ll*(t[rt<<1|1].r-t[rt<<1|1].l+1)*t[rt].lazy%p)%p;
        t[rt].lazy=0;
    }
}
void build(int rt,int l,int r){
    t[rt].l=l; t[rt].r=r;
    if(l==r){
        t[rt].sum=fnval[l];
        t[rt].lazy=0;
        return;
    }
    int mid=l+r>>1;
    build(rt<<1,l,mid);
    build(rt<<1|1,mid+1,r);
    pushup(rt);
}
void update(int rt,int l,int r,int L,int R,LL val){
    if(L<=l&&r<=R){
        t[rt].lazy=(t[rt].lazy+val)%p;
        t[rt].sum=(t[rt].sum+1ll*val*(t[rt].r-t[rt].l+1)%p)%p;
        return;
    }
    pushdown(rt);
    int mid=l+r>>1;
    if(L<=mid) update(rt<<1,l,mid,L,R,val);
    if(R>mid) update(rt<<1|1,mid+1,r,L,R,val);
    pushup(rt);
}
LL query(int rt,int l,int r,int L,int R){
    LL ret=0;
    if(L<=l&&r<=R) return t[rt].sum;
    pushdown(rt);
    int mid=l+r>>1;
    if(L<=mid) ret=(ret+query(rt<<1,l,mid,L,R))%p;
    if(R>mid) ret=(ret+query(rt<<1|1,mid+1,r,L,R))%p;
    pushup(rt);
    return ret;
}
void dfs1(int u,int fat){
    sz[u]=1; dep[u]=dep[fat]+1; fa[u]=fat;
    for(int i=hd[u];i;i=nxt[i]){
        int v=to[i]; if(v==fat) continue;
        dfs1(v,u);sz[u]+=sz[v];
        if(sz[v]>sz[son[u]]) son[u]=v; 
    }
    return;
}//dep,fa,子树大小(含它自己),重儿子编号son
void dfs2(int u,int topf){
    id[u]=++tim; fnval[tim]=stval[u]; top[u]=topf;
    if(!son[u]) return;
    dfs2(son[u],topf);
    for(int i=hd[u];i;i=nxt[i]){
        int v=to[i]; if(v==fa[u]||v==son[u]) continue;
        dfs2(v,v);
    }
    return;
}//新编号,赋值到新编号上,所在链的顶端,处理每条链
void updrange(int x,int y,int z){
    z=(z%p+p)%p;
    while(top[x]!=top[y]){
        if(dep[top[x]]<dep[top[y]]) swap(x,y);
        update(1,1,n,id[top[x]],id[x],z); x=fa[top[x]];
    }
    if(dep[x]>dep[y]) swap(x,y);
    update(1,1,n,id[x],id[y],z); 
    return;
}
LL qrange(int x,int y){
    LL ret=0;
    while(top[x]!=top[y]){
        if(dep[top[x]]<dep[top[y]]) swap(x,y);
        ret=(ret+query(1,1,n,id[top[x]],id[x]))%p; x=fa[top[x]];
    }
    if(dep[x]>dep[y]) swap(x,y);
    ret=(ret+query(1,1,n,id[x],id[y]))%p;
    return (ret+p)%p;
}
void updson(int x,int z){
    update(1,1,n,id[x],id[x]+sz[x]-1,(z%p+p)%p);
    return;
}
LL qson(int x){
    return (query(1,1,n,id[x],id[x]+sz[x]-1)+p)%p;
}
int main(){
    scanf("%d%d%d%d",&n,&m,&r,&p);
    for(int i=1;i<=n;i++)
        scanf("%lld",&stval[i]),stval[i]=(stval[i]%p+p)%p;
    for(int i=1,u,v;i<n;i++){
        scanf("%d%d",&u,&v);
        add(u,v); add(v,u);
    }
    dfs1(r,0); dfs2(r,r);
    build(1,1,n);
    for(int i=1,tp,x,y;i<=m;i++){
        LL z;
        scanf("%d",&tp);
        if(tp==1){
            scanf("%d%d%lld",&x,&y,&z); updrange(x,y,z);
        } else if(tp==2){
            scanf("%d%d",&x,&y); printf("%lld\n",qrange(x,y));
        } else if(tp==3){
            scanf("%d%lld",&x,&z); updson(x,z);
        } else{
            scanf("%d",&x); printf("%lld\n",qson(x));
        }
    }    
    return 0;
}
posted @ 2020-11-19 22:22  ACwisher  阅读(151)  评论(0编辑  收藏  举报