BZOJ 4034 洛谷3178 树上操作题解

一个很裸的树链剖分模板。注意一下数据范围,有的地方要开longlong,这就是唯一的陷阱了。

# include<iostream>
# include<cstdio>
# include<cmath>
# include<algorithm>
using std::max;
using std::min;
const int mn = 100005;
typedef long long LL;
struct edge{int to,next;};
edge e[mn*2];
int edge_max,head[mn];
void add(int x,int y)
{
    e[++edge_max].to=y;
    e[edge_max].next=head[x];
    head[x]=edge_max;
}
int n,m;
int val[mn];
int fa[mn],siz[mn];
void dfs1(int x)
{
    siz[x]=1;
    for(int i=head[x];i;i=e[i].next)
    {
       int v=e[i].to;
       if(v==fa[x])
         continue;
       fa[v]=x;
       dfs1(v);
       siz[x]+=siz[v];
    }
}
int tot,id[mn],bl[mn],mx[mn];
void dfs2(int x,int chain)
{
    int k=0;
    ++tot;
    id[x]=mx[x]=tot;
    bl[x]=chain;
    for(int i=head[x];i;i=e[i].next)
    {
        if(e[i].to!=fa[x] && siz[e[i].to]>siz[k])
            k=e[i].to;
    }
    if(k==0)
        return ;
    dfs2(k,chain);
    mx[x]=max(mx[x],mx[k]);
    for(int i=head[x];i;i=e[i].next)
    {
        if(fa[x]!=e[i].to && e[i].to!=k)
        {
            dfs2(e[i].to,e[i].to);
            mx[x]=max(mx[x],mx[e[i].to]);
        }
    }
}
struct node{int l,r;LL sum,tag;};
node tr[mn*4];
inline void pushdown(int cur)
{
    if(tr[cur].tag && tr[cur].l!=tr[cur].r)
    {
        tr[cur<<1].tag+=tr[cur].tag;
        tr[cur<<1|1].tag+=tr[cur].tag;
        tr[cur<<1].sum+=(tr[cur<<1].r-tr[cur<<1].l+1)*tr[cur].tag;
        tr[cur<<1|1].sum+=(tr[cur<<1|1].r-tr[cur<<1|1].l+1)*tr[cur].tag;
        tr[cur].tag=0;
    }
}
void build(int l,int r,int cur)
{
    tr[cur].l=l,tr[cur].r=r;
    if(l==r)
    {
        return ;
    }
    int mid=l+r>>1;
    build(l,mid,cur<<1);
    build(mid+1,r,cur<<1|1);
}
void update(int cur,int l,int r,LL x)//此处要开longlong防止下面爆掉
{
    if(tr[cur].r<l || tr[cur].l>r)
        return ;
    pushdown(cur);
    if(tr[cur].l>=l && tr[cur].r<=r)
    {
        tr[cur].sum+=(tr[cur].r-tr[cur].l+1)*x;//就是这里容易爆掉
        tr[cur].tag+=x;
        return ;
    }
    update(cur<<1,l,r,x);
    update(cur<<1|1,l,r,x);
    tr[cur].sum=tr[cur<<1].sum+tr[cur<<1|1].sum;
}
LL qsum(int cur,int l,int r)
{
    if(tr[cur].r<l || tr[cur].l>r)
        return 0;
    pushdown(cur);
    if(tr[cur].l>=l && tr[cur].r<=r)
      return tr[cur].sum;
    LL tmp=0;
    tmp+=qsum(cur<<1,l,r);
    tmp+=qsum(cur<<1|1,l,r);
    tr[cur].sum=tr[cur<<1].sum+tr[cur<<1|1].sum;
    return tmp;
}
void querry(int x)
{
    LL ans=0;
    while(bl[x]!=1)
    {
        ans+=qsum(1,id[bl[x]],id[x]);
        x=fa[bl[x]];
    }
    ans+=qsum(1,1,id[x]);
    printf("%lld\n",ans);
}
int main()
{
    int opt,x,y;
    scanf("%d%d",&n,&m);
    for(int i=1;i<=n;i++)
        scanf("%d",&val[i]);
    for(int i=1;i<n;i++)
    {
        scanf("%d%d",&x,&y);
        add(x,y);
        add(y,x);
    }
    dfs1(1);
    dfs2(1,1);
    build(1,n,1);
    for(int i=1;i<=n;i++)
        update(1,id[i],id[i],val[i]);
    for(int i=1;i<=m;i++)
    {
       scanf("%d",&opt);
       if(opt==1)
       {
           scanf("%d%d",&x,&y);
           update(1,id[x],id[x],y*1ll);
       }
       else if(opt==2)
       {
           scanf("%d%d",&x,&y);
           update(1,id[x],mx[x],y*1ll);
       }
       else {
        scanf("%d",&x);
        querry(x);
       }
    }
    return 0;
}

  

posted @ 2018-04-05 18:13  logeadd  阅读(167)  评论(0编辑  收藏  举报