树链剖分

树链剖分(点)

解决:

  1. 将两个节点之间的简单路径上的点的权值加上v
  2. 求两个节点之间的简单路径上的点的权值之和
  3. 以某一节点为根节点的子树内所有的点的权值加上v
  4. 求某一节点为根节点的子树内所有的点的权值之和

思想:
将数划分成若干链,用线段树或者树状数组对这些链进行操作

重儿子:对于非叶子节点x,以x的子节点的为根的子树中节点数最多的节点。
轻儿子:对于非叶子节点x,除重儿子外其余子节点都是轻儿子。
叶子节点没有轻重儿子。
重边:连接任意两个重儿子的边
轻边:其余的边
重链:由重边连起来的路径

在dfs2中优先遍历节点重儿子,所以,某条重链上的节点的重新标号是连续的。
节点x的子树编号范围:id[x]~tot[x]+id[x]-1,所以子树更新等价于线段树区间更新。

#include<bits/stdc++.h>
using namespace std;
#define lson (i<<1)
#define rson (i<<1|1)
#define mid ((l+r)>>1)
#define lowbit(x) ((x)&(-x))
const int MAXN=1e5+8;
int N,M,R,P;
int a[MAXN],deep[MAXN],tot[MAXN],fa[MAXN],son[MAXN];
//a:点的权值
//deep:点的深度
//tot:以某点为子树的节点数量
//fa:节点父节点
//son:节点重儿子
vector<int>mp[MAXN];
int dfs1(int now,int pre,int dep){
    deep[now]=dep;
    fa[now]=pre;
    tot[now]=1;
    int max_son=-1;
    for(int i=0;i<mp[now].size();++i){
        int to=mp[now][i];
        if(to==pre)continue;
        tot[now]+=dfs1(to,now,dep+1);
        if(tot[to]>max_son){
            max_son=tot[to];
            son[now]=to;
        }
    }
    return tot[now];
}
int b[MAXN],id[MAXN],top[MAXN],cnt;
void dfs2(int now,int topfa){
    id[now]=++cnt;
    b[cnt]=a[now];
    top[now]=topfa;
    if(!son[now])return;
    dfs2(son[now],topfa);
    for(int i=0;i<mp[now].size();++i){
        int to=mp[now][i];
        if(!id[to])dfs2(to,to);
    }
}
int sum[MAXN<<2],fg[MAXN<<2];
inline void up(int i){sum[i]=(sum[lson]+sum[rson]+P)%P;}
inline void down(int i,int l,int r){
    sum[lson]=(sum[lson]+(mid-l+1)*fg[i]+P)%P;
    sum[rson]=(sum[rson]+(r-mid)*fg[i]+P)%P;
    fg[lson]=(fg[lson]+fg[i]+P)%P;
    fg[rson]=(fg[rson]+fg[i]+P)%P;
    fg[i]=0;
}
void build(int i=1,int l=1,int r=N){
    if(l==r){sum[i]=b[l];return;}
    build(lson,l,mid);
    build(rson,mid+1,r);
    up(i);
}
void interval_add(int x,int y,int v,int i=1,int l=1,int r=N){
    if(x<=l&&r<=y){
        sum[i]=(sum[i]+(r-l+1)*v+P)%P;
        fg[i]=(fg[i]+v)%P;
        return;
    }
    if(fg[i])down(i,l,r);
    if(x<=mid)interval_add(x,y,v,lson,l,mid);
    if(y>mid)interval_add(x,y,v,rson,mid+1,r);
    up(i);
}
int interval_sum(int x,int y,int i=1,int l=1,int r=N){
    if(x<=l&&r<=y)return sum[i];
    if(fg[i])down(i,l,r);
    int res=0;
    if(x<=mid)res=(res+interval_sum(x,y,lson,l,mid)+P)%P;
    if(y>mid)res=(res+interval_sum(x,y,rson,mid+1,r)+P)%P;
    return res;
}
int path_sum(int x,int y){
    int res=0;
    while(top[x]!=top[y]){
        if(deep[top[x]]<deep[top[y]])swap(x,y);
        res=(res+interval_sum(id[top[x]],id[x])+P)%P;
        x=fa[top[x]];
    }
    if(deep[x]>deep[y])swap(x,y);
    res=(res+interval_sum(id[x],id[y])+P)%P;
    return res;
}
void path_add(int x,int y,int v){
    while(top[x]!=top[y]){
        if(deep[top[x]]<deep[top[y]])swap(x,y);
        interval_add(id[top[x]],id[x],v);
        x=fa[top[x]];
    }
    if(deep[x]>deep[y])swap(x,y);
    interval_add(id[x],id[y],v);
}
void son_add(int x,int v){interval_add(id[x],tot[x]+id[x]-1,v);}
int son_sum(int x){return interval_sum(id[x],tot[x]+id[x]-1);}
int main() {
    scanf("%d%d%d%d",&N,&M,&R,&P);
    for(int i=1; i<=N; ++i)scanf("%d",a+i);
    int op,x,y,z;
    for(int i=1; i<N; ++i) {
        scanf("%d%d",&x,&y);
        mp[x].push_back(y);
        mp[y].push_back(x);
    }
    dfs1(R,0,1);
    dfs2(R,R);
    build();
    while(M--){
        scanf("%d",&op);
        if(op==1){
            scanf("%d%d%d",&x,&y,&z);
            path_add(x,y,z);
        }
        else if(op==2){
            scanf("%d%d",&x,&y);
            printf("%d\n",path_sum(x,y));
        }
        else if(op==3){
            scanf("%d%d",&x,&z);
            son_add(x,z);
        }
        else{
            scanf("%d",&x);
            printf("%d\n",son_sum(x));
        }
    }
    return 0;
}
posted @ 2020-12-19 22:45  肆之月  阅读(72)  评论(0编辑  收藏  举报