bzoj4034[HAOI2015]树上操作 树链剖分+线段树

4034: [HAOI2015]树上操作

Time Limit: 10 Sec  Memory Limit: 256 MB
Submit: 6163  Solved: 2025
[Submit][Status][Discuss]

Description

有一棵点数为 N 的树,以点 1 为根,且树点有边权。然后有 M 个

操作,分为三种:

操作 1 :把某个节点 x 的点权增加 a 。

操作 2 :把某个节点 x 为根的子树中所有点的点权都增加 a 。

操作 3 :询问某个节点 x 到根的路径中所有点的点权和。

Input

第一行包含两个整数 N, M 。表示点数和操作数。接下来一行 N 个整数,表示树中节点的初始权值。接下来 N-1 

行每行三个正整数 fr, to , 表示该树中存在一条边 (fr, to) 。再接下来 M 行,每行分别表示一次操作。其中

第一个数表示该操作的种类( 1-3 ) ,之后接这个操作的参数( x 或者 x a ) 。

Output

对于每个询问操作,输出该询问的答案。答案之间用换行隔开。

Sample Input

5 5
1 2 3 4 5
1 2
1 4
2 3
2 5
3 3
1 2 1
3 5
2 1 2
3 3

Sample Output

6
9
13

HINT

对于 100% 的数据, N,M<=100000 ,且所有输入数据的绝对值都不会超过 10^6 。

Source

鸣谢bhiaibogf提供

 

需要注意的是子树整体加减
可以发现,一棵子树一定是一段连续区间
树链剖分的时候记录in[x]和out[x], 夹在它们之间的就是x子树区间

#include<cstdio>
#include<iostream>
#include<algorithm>
#include<cstring>
#define ll long long
#define ls u<<1
#define rs ls|1
#define N 100050
using namespace std;
int n,m,tot,cnt,in[N],out[N],hd[N],fa[N],val[N];
int dep[N],v[N],son[N],siz[N],tid[N],tp[N];ll sum[N<<2],lz[N<<2];
struct edge{int v,next;}e[N<<1];
void adde(int u,int v){
    e[++tot].v=v;
    e[tot].next=hd[u];
    hd[u]=tot;
}
void dfs1(int u,int pre){
    fa[u]=pre;dep[u]=dep[pre]+1;siz[u]=1;
    for(int i=hd[u];i;i=e[i].next){
        int v=e[i].v;
        if(v==pre)continue;
        dfs1(v,u);siz[u]+=siz[v];
        if(siz[v]>siz[son[u]])son[u]=v;
    }
}
void dfs2(int u,int anc){
    if(!u)return;
    tid[u]=++cnt;v[cnt]=val[u];
    in[u]=cnt;tp[u]=anc;
    dfs2(son[u],anc);
    for(int i=hd[u];i;i=e[i].next){
        int v=e[i].v;
        if(v==fa[u]||v==son[u])continue;
        dfs2(v,v);
    }
    out[u]=cnt;
}
void pushup(int u){sum[u]=sum[ls]+sum[rs];}
void pushdown(int u,int l,int r){
    if(!lz[u])return;
    int mid=l+r>>1;ll x=lz[u];
    lz[ls]+=x;lz[rs]+=x;
    sum[ls]+=x*(mid-l+1);
    sum[rs]+=x*(r-mid);
    lz[u]=0;
}
void build(int u,int l,int r){
    if(l==r){
        sum[u]=v[l];
        return;
    }
    int mid=l+r>>1;
    build(ls,l,mid);
    build(rs,mid+1,r);
    pushup(u);
}
void update(int u,int L,int R,int l,int r,int w){
    if(l<=L&&R<=r){
        sum[u]+=1ll*w*(R-L+1);
        lz[u]+=w;return;
    }
    pushdown(u,L,R);
    int mid=L+R>>1;
    if(l<=mid)update(ls,L,mid,l,r,w);
    if(r>mid)update(rs,mid+1,R,l,r,w);
    pushup(u);
}
ll query(int u,int L,int R,int l,int r){
    if(l<=L&&R<=r)return sum[u];
    pushdown(u,L,R);
    int mid=L+R>>1;ll ret=0;
    if(l<=mid)ret+=query(ls,L,mid,l,r);
    if(r>mid)ret+=query(rs,mid+1,R,l,r);
    return ret;
}
ll jump(int x,int y){
    int fx=tp[x],fy=tp[y];
    ll ret=0;
    while(fx!=fy){
        if(dep[fx]<dep[fy]){
            swap(fx,fy);
            swap(x,y);
        }
        ret+=query(1,1,cnt,tid[fx],tid[x]);
        x=fa[fx];fx=tp[x];
    }
    if(dep[x]>dep[y])swap(x,y);
    ret+=query(1,1,cnt,tid[x],tid[y]);
    return ret;
}
int main(){
    scanf("%d%d",&n,&m);
    for(int i=1;i<=n;i++)scanf("%d",&val[i]);
    for(int i=1;i<n;i++){
        int a,b;
        scanf("%d%d",&a,&b);
        adde(a,b);adde(b,a);
    }
    dfs1(1,0);dfs2(1,1);
    build(1,1,cnt);
    int op,a,b;
    while(m--){
        scanf("%d",&op);
        if(op==1){
            scanf("%d%d",&a,&b);
            update(1,1,cnt,tid[a],tid[a],b);
        }
        if(op==2){
            scanf("%d%d",&a,&b);
            update(1,1,cnt,in[a],out[a],b);
        }
        if(op==3){
            scanf("%d",&a);
            printf("%lld\n",jump(a,1));
        }
    }
    return 0;
}
posted @ 2017-12-20 11:43  _wsy  阅读(150)  评论(0编辑  收藏  举报