BZOJ.4034 [HAOI2015]树上操作 ( 点权树链剖分 线段树 )

BZOJ.4034 [HAOI2015]树上操作 ( 点权树链剖分 线段树 )

题意分析

有一棵点数为 N 的树,以点 1 为根,且树点有边权。然后有 M 个
操作,分为三种:
操作 1 :把某个节点 x 的点权增加 a 。
操作 2 :把某个节点 x 为根的子树中所有点的点权都增加 a 。
操作 3 :询问某个节点 x 到根的路径中所有点的点权和。

首先上来两次dfs树链剖分问题不大,然后主要是这三个操作分别如何取实现。
先说第一种,点权修改,直接用线段树的点更新就好了。第二个子树修改,貌似树链剖分不能解决子树问题,但是仔细想一下其实是可以的,因为对任意一个节点x,其子树的节点编号都是连续的,原因是在dfs的时候,我们优先遍历这个节点的重儿子形成重链,其次遍历轻儿子形成轻链,所以他的儿子们是连续存储的。而需要修改的区间是[newid[x],newid[x]+size[x]-1],这样利用线段树的区间更新就可已解决了。最后一个询问操作,利用线段树的区间和,完美解决。

直的注意的一点是,在点更新的时候,别忘记有PushDown操作,因为已经打上了lazy标记

代码总览

#include <bits/stdc++.h>
#define ll long long
#define nmax 200820
using namespace std;
int fa[nmax],son[nmax],sz[nmax],newid[nmax],hashback[nmax],dep[nmax],top[nmax],newout[nmax];
int num,tot,head[nmax];
ll data[nmax];
struct edge{
    int to;
    int next;
}edg[nmax<<2];
struct tre{
    int l,r;
    ll val,lazy;
    int mid(){
        return (l+r)>>1;
    }
}tree[nmax<<2];
void add(int u, int v){
    edg[tot].to = v;
    edg[tot].next = head[u];
    head[u] = tot++;
}
void dfsFirst(int rt, int f,int d){
    dep[rt] = d;
    fa[rt] = f;
    sz[rt] = 1;
    for(int i = head[rt]; i!= -1; i = edg[i].next){
        int nxt = edg[i].to;
        if(nxt != f){
            dfsFirst(nxt,rt,d+1);
            sz[rt]+=sz[nxt];
            if(son[rt] == -1 || sz[nxt] > sz[son[rt]]){
                son[rt] = nxt;
            }
        }
    }
}
void dfsSecond(int rt, int tp){
    top[rt] = tp;
    newid[rt] = ++num;
    hashback[num] = rt;
    if(son[rt] == -1) return;
    dfsSecond(son[rt],tp);
    for(int i = head[rt];i != -1; i = edg[i].next){
        int nxt = edg[i].to;
        if(nxt != son[rt] && nxt != fa[rt])
            dfsSecond(nxt,nxt);
    }
}
void init(){
    memset(tree,0,sizeof tree);
    memset(head,-1,sizeof head);
    memset(son,-1,sizeof son);
    memset(edg,0,sizeof edg);
    memset(hashback,0,sizeof hashback);
    memset(data,0,sizeof data);
    memset(newid,0,sizeof newid);
    tot = num = 0;
}
void PushUp(int rt){
    tree[rt].val = tree[rt<<1].val + tree[rt<<1|1].val;
}
void PushDown(int rt){
    if(tree[rt].lazy){
        tree[rt<<1].lazy += tree[rt].lazy;
        tree[rt<<1|1].lazy += tree[rt].lazy;
        tree[rt<<1].val += tree[rt].lazy*(ll)(tree[rt<<1].r - tree[rt<<1].l + 1);
        tree[rt<<1|1].val +=tree[rt].lazy*(ll)(tree[rt<<1|1].r - tree[rt<<1|1].l + 1);
        tree[rt].lazy = 0;
    }
}
void Build(int l, int r, int rt){
    tree[rt].l = l; tree[rt].r = r;
    if(l == r){
        tree[rt].val = data[hashback[l]];
        return;
    }
    Build(l,tree[rt].mid(),rt<<1);
    Build(tree[rt].mid()+1,r,rt<<1|1);
    PushUp(rt);
}
void UpdatePoint(ll val, int pos, int rt){
    if(tree[rt].l == tree[rt].r){
        tree[rt].val += (ll)val;
        return;
    }
    PushDown(rt);
    if(pos <= tree[rt].mid()) UpdatePoint(val,pos,rt<<1);
    else UpdatePoint(val,pos,rt<<1|1);
    PushUp(rt);
}
void UpdateInterval(ll val, int l, int r, int rt){
    if(tree[rt].l >r || tree[rt].r < l) return;
    if(tree[rt].l >= l && tree[rt].r <= r){
        tree[rt].val += val*(ll)(tree[rt].r - tree[rt].l +1);
        tree[rt].lazy += val;
        return;
    }
    PushDown(rt);
    UpdateInterval(val,l,r,rt<<1) ;
    UpdateInterval(val,l,r,rt<<1|1);
    PushUp(rt);
}
ll QuerySUM(int l,int r,int rt){

    if(l>tree[rt].r || r<tree[rt].l) return 0;
    PushDown(rt);
    if(l <= tree[rt].l && tree[rt].r <= r) return tree[rt].val;
    return QuerySUM(l,r,rt<<1) + QuerySUM(l,r,rt<<1|1);
}
ll Find_SUM(int x, int y){
    int tx = top[x],ty =top[y];
    ll ans = 0ll;
    while(tx != ty){
        if(dep[tx] < dep[ty]){
            swap(x,y);
            swap(tx,ty);
        }
        ans += QuerySUM(newid[tx],newid[x],1);
        x = fa[tx]; tx = top[x];
    }
    if(dep[x] > dep[y]) swap(x,y);
    ans += QuerySUM(newid[x],newid[y],1);
    return ans;
}

int n,m;
int main()
{
    //freopen("in.txt","r",stdin);

    while(scanf("%d %d",&n,&m)!=EOF){
        init();
        for(int i =1;i<=n;++i) scanf("%lld", &data[i]);
        int u,v,x,y;
        int op;
        for(int i =1;i<=n-1;++i){
            scanf("%d %d",&u,&v);
            add(u,v);
            add(v,u);
        }
        dfsFirst(1,0,1);
        dfsSecond(1,1);
        Build(1,n,1);
        ll val;
        for(int i = 0;i<m;++i){
            scanf("%d",&op);
            if(op == 1){// x add y
                scanf("%d %lld",&x,&val);
                UpdatePoint(val,newid[x],1);
            }else if(op == 2){//x root add y
                scanf("%d %lld",&x,&val);
                UpdateInterval(val,newid[x],newid[x]+sz[x]-1,1);
            }else{// query (1,x)
                scanf("%d",&x);
                printf("%lld\n",Find_SUM(1,x));
            }

        }
    }
    return 0;
}
posted @ 2017-08-14 10:55  pengwill  阅读(170)  评论(0编辑  收藏  举报