BZOJ4127Abs——树链剖分+线段树

题目描述

给定一棵树,设计数据结构支持以下操作
1 u v d  表示将路径 (u,v) 加d
2 u v  表示询问路径 (u,v) 上点权绝对值的和

输入

第一行两个整数n和m,表示结点个数和操作数
接下来一行n个整数a_i,表示点i的权值

接下来n-1行,每行两个整数u,v表示存在一条(u,v)的边

接下来m行,每行一个操作,输入格式见题目描述

输出

对于每个询问输出答案

样例输入

4 4
-4 1 5 -2
1 2
2 3
3 4
2 1 3
1 1 4 3
2 1 3
2 3 4

样例输出

10
13
9

提示

对于100%的数据,n,m <= 10^5 且 0<= d,|a_i|<= 10^8

 

如果都是正数直接树链剖分+线段树就行了。

现在有了负数,那不是再维护一个区间正数个数就好了?显然是不够的。

因为区间修改时会把一些负数变为正数,会改变区间正数的个数,所以我们要维护区间三个值:

1、区间绝对值之和

2、区间非负数个数

3、区间最大的负数

当每次修改一个区间时如果这个区间的最大负数会变成非负数,那么说明这个区间的非负数个数会改变,因此要重构这个区间。

怎么重构呢?

对于这个区间的左右子区间,对于不需要重构的子区间下传标记,对于需要重构的子区间就递归重构下去。

因为每个数最多只会被重构一次,因此重构均摊O(nlogn)。总时间复杂度还是O(mlogn)级别。

#include<set>
#include<map>
#include<stack>
#include<queue>
#include<cmath>
#include<cstdio>
#include<vector>
#include<bitset>
#include<cstring>
#include<iostream>
#include<algorithm>
#define ll long long
using namespace std;
int num[800010];
int mx[800010];
ll sum[800010];
int d[100010];
int f[100010];
int son[100010];
int size[100010];
int top[100010];
int to[200010];
int tot;
int head[100010];
int s[100010];
int q[100010];
int n,m;
int x,y,z;
int opt;
int cnt;
ll a[800010];
int next[200010];
int v[100010];
int merge(int x,int y)
{
    if(x<0&&y<0)
    {
        return max(x,y);
    }
    if(x<0)
    {
        return x;
    }
    if(y<0)
    {
        return y;
    }
    return 0;
}
void add(int x,int y)
{
    tot++;
    next[tot]=head[x];
    head[x]=tot;
    to[tot]=y;
}
void dfs(int x)
{
    size[x]=1;
    d[x]=d[f[x]]+1;
    for(int i=head[x];i;i=next[i])
    {
        if(to[i]!=f[x])
        {
            f[to[i]]=x;
            dfs(to[i]);
            size[x]+=size[to[i]];
            if(size[to[i]]>size[son[x]])
            {
                son[x]=to[i];
            }
        }
    }
}
void dfs2(int x,int tp)
{
    s[x]=++cnt;
    top[x]=tp;
    q[cnt]=x;
    if(son[x])
    {
        dfs2(son[x],tp);
    }
    for(int i=head[x];i;i=next[i])
    {
        if(to[i]!=f[x]&&to[i]!=son[x])
        {
            dfs2(to[i],to[i]);
        }
    }
}
void pushup(int rt)
{
    num[rt]=num[rt<<1]+num[rt<<1|1];
    sum[rt]=sum[rt<<1]+sum[rt<<1|1];
    mx[rt]=merge(mx[rt<<1],mx[rt<<1|1]);
}
void pushdown(int rt,bool x,bool y,int l,int r)
{
    if(a[rt])
    {
        int mid=(l+r)>>1;
        if(x)
        {
            if(mx[rt<<1])
            {
                mx[rt<<1]+=a[rt];
            }
            sum[rt<<1]+=1ll*(2*num[rt<<1]-(mid-l+1))*a[rt];
            a[rt<<1]+=a[rt];
        }
        if(y)
        {
            if(mx[rt<<1|1])
            {
                mx[rt<<1|1]+=a[rt];
            }
            sum[rt<<1|1]+=1ll*(2*num[rt<<1|1]-(r-mid))*a[rt];
            a[rt<<1|1]+=a[rt];
        }
        a[rt]=0;
    }
}
void build(int rt,int l,int r)
{
    if(l==r)
    {
        if(v[q[l]]<0)
        {
            mx[rt]=v[q[l]];
        }
        else
        {
            num[rt]=1;
        }
        sum[rt]=abs(v[q[l]]);
        return ;
    }
    int mid=(l+r)>>1;
    build(rt<<1,l,mid);
    build(rt<<1|1,mid+1,r);
    pushup(rt);
}
void rebuild(int rt,int l,int r,ll c)
{
    if(l==r)
    {
        num[rt]=1;
        sum[rt]=mx[rt]+c;
        mx[rt]=0;
        return ;
    }
    int mid=(l+r)>>1;
    c+=a[rt];
    a[rt]=c;
    if(mx[rt<<1]&&mx[rt<<1]+c>=0&&mx[rt<<1|1]&&mx[rt<<1|1]+c>=0)
    {
        a[rt]=0;
        rebuild(rt<<1,l,mid,c);
        rebuild(rt<<1|1,mid+1,r,c);
    }
    else if(mx[rt<<1]&&mx[rt<<1]+c>=0)
    {
        pushdown(rt,0,1,l,r);
        rebuild(rt<<1,l,mid,c);
    }
    else if(mx[rt<<1|1]&&mx[rt<<1|1]+c>=0)
    {
        pushdown(rt,1,0,l,r);
        rebuild(rt<<1|1,mid+1,r,c);
    }
    pushup(rt);
}
void change(int rt,int l,int r,int L,int R,int c)
{
    if(L<=l&&r<=R)
    {
        if(mx[rt]+c>=0&&mx[rt])
        {
            rebuild(rt,l,r,c);
        }
        else
        {
            if(mx[rt])
            {
                mx[rt]+=c;
            }
            a[rt]+=c;
            sum[rt]+=1ll*(2*num[rt]-(r-l+1))*c;
        }
        return ;
    }
    int mid=(l+r)>>1;
    pushdown(rt,1,1,l,r);
    if(L<=mid)
    {
        change(rt<<1,l,mid,L,R,c);
    }
    if(R>mid)
    {
        change(rt<<1|1,mid+1,r,L,R,c);
    }
    pushup(rt);
}
ll query(int rt,int l,int r,int L,int R)
{
    if(L<=l&&r<=R)
    {
        return sum[rt];
    }
    pushdown(rt,1,1,l,r);
    int mid=(l+r)>>1;
    long long res=0;
    if(L<=mid)
    {
        res+=query(rt<<1,l,mid,L,R);
    }
    if(R>mid)
    {
        res+=query(rt<<1|1,mid+1,r,L,R);
    }
    return res;
}
void updata(int x,int y,int z)
{
    while(top[x]!=top[y])
    {
        if(d[top[x]]<d[top[y]])
        {
            swap(x,y);
        }
        change(1,1,n,s[top[x]],s[x],z);
        x=f[top[x]];
    }
    if(d[x]>d[y])
    {
        swap(x,y);
    }
    change(1,1,n,s[x],s[y],z);
}
ll downdata(int x,int y)
{
    ll res=0;
    while(top[x]!=top[y])
    {
        if(d[top[x]]<d[top[y]])
        {
            swap(x,y);
        }
        res+=query(1,1,n,s[top[x]],s[x]);
        x=f[top[x]];
    }
    if(d[x]>d[y])
    {
        swap(x,y);
    }
    res+=query(1,1,n,s[x],s[y]);
    return res;
}
int main()
{
    scanf("%d%d",&n,&m);
    for(int i=1;i<=n;i++)
    {
        scanf("%d",&v[i]);
    }
    for(int i=1;i<n;i++)
    {
        scanf("%d%d",&x,&y);
        add(x,y);
        add(y,x);
    }
    dfs(1);
    dfs2(1,1);
    build(1,1,n);
    while(m--)
    {
        scanf("%d",&opt);
        scanf("%d%d",&x,&y);
        if(opt==1)
        {
            scanf("%d",&z);
            updata(x,y,z);
        }
        else
        {
            printf("%lld\n",downdata(x,y));
        }
    }
}
posted @ 2018-09-25 15:58  The_Virtuoso  阅读(307)  评论(0编辑  收藏  举报