bzoj4034: [HAOI2015]树上操作(树链剖分)


4034: [HAOI2015]树上操作

Time Limit: 10 Sec  Memory Limit: 256 MB
Submit: 4423  Solved: 1411
[Submit][Status][Discuss]

Description

有一棵点数为 N 的树,以点 1 为根,且树点有边权。然后有 M 个
操作,分为三种:
操作 1 :把某个节点 x 的点权增加 a 。
操作 2 :把某个节点 x 为根的子树中所有点的点权都增加 a 。
操作 3 :询问某个节点 x 到根的路径中所有点的点权和。

Input

第一行包含两个整数 N, M 。表示点数和操作数。接下来一行 N 个整数,表示树中节点的初始权值。接下来 N-1 
行每行三个正整数 fr, to , 表示该树中存在一条边 (fr, to) 。再接下来 M 行,每行分别表示一次操作。其中
第一个数表示该操作的种类( 1-3 ) ,之后接这个操作的参数( x 或者 x a ) 。

Output

对于每个询问操作,输出该询问的答案。答案之间用换行隔开。

Sample Input

5 5
1 2 3 4 5
1 2
1 4
2 3
2 5
3 3
1 2 1
3 5
2 1 2
3 3

Sample Output

6
9
13

HINT

 对于 100% 的数据, N,M<=100000 ,且所有输入数据的绝对值都不会超过 10^6 。


啊啊啊啊啊啊啊

一个long long 没转好

调了一个多小时

难怪我小数据对拍拍不出来

以后涉及long long的一定要仔细处理!!!

对于树链剖分

有一个性质:

从根节点向下对所有结点的重链进行编号,的对于一条重链来说,其编号一定是连续的,对于一个结点来说,其子树的编号一定是连续的。

所以第二次dfs时处理出以该节点为根的子树的区间左端点和右端点

就是线段数的区间修改了

#include<cstdio>
#define ll long long
#include<iostream>
const int N=100018;
struct node
{
    int l,r;
    ll col,sum;
}e[N*4];
struct edgt
{
    int to,next;
}f[N*2];
int vi[N],first[N],cnt=1,ek=0,id[N],dep[N],size[N],fa[N],left[N],right[N],top[N];
void insert(int u,int v)
{
    f[++cnt].to=v;f[cnt].next=first[u];first[u]=cnt;
    f[++cnt].to=u;f[cnt].next=first[v];first[v]=cnt;
}
void dfs1(int ro)
{
    size[ro]=1;
    for(int k=first[ro];k;k=f[k].next)
    {
        if(f[k].to==fa[ro]) continue;
        dep[f[k].to]=dep[ro]+1;
        fa[f[k].to]=ro;
        dfs1(f[k].to);
        size[ro]+=size[f[k].to];
    }
}
void dfs2(int ro,int chain)
{
    int i=0;    left[ro]=id[ro]=++ek;top[ro]=chain;
    for(int k=first[ro];k;k=f[k].next)
    if(dep[f[k].to]>dep[ro]&&size[f[k].to]>size[i])       i=f[k].to;
    if(!i)
    {
        right[ro]=ek;return;
    }
    dfs2(i,chain);
    for(int k=first[ro];k;k=f[k].next)
    if(dep[f[k].to]>dep[ro]&&f[k].to!=i) dfs2(f[k].to,f[k].to);
    right[ro]=ek;
}
void build(int ro,int l,int r)
{
    e[ro].l=l;e[ro].r=r;
    if(l==r)    return;
    int mid=(l+r)/2;
    build(2*ro,l,mid);
    build(2*ro+1,mid+1,r);
}
void dowm(int ro)
{
    if(e[ro].col)
    {
        int u=2*ro,v=u+1;
        e[u].col+=e[ro].col;
        e[v].col+=e[ro].col;
        e[u].sum+=e[ro].col*(e[u].r-e[u].l+1);
        e[v].sum+=e[ro].col*(e[v].r-e[v].l+1);
        e[ro].col=0;
    }
}
void change(int ro,int z,int y,ll u)
{
    if(z<=e[ro].l&&e[ro].r<=y)
    {
        e[ro].sum+=u*(e[ro].r-e[ro].l+1);
        e[ro].col+=u;
        return;
    }
    dowm(ro);
    int mid=(e[ro].l+e[ro].r)/2;
    if(z<=mid)   change(2*ro,z,y,u);
    if(mid<y)    change(2*ro+1,z,y,u);
    e[ro].sum=e[2*ro].sum+e[2*ro+1].sum;
}
ll sum(int ro,int z,int y)
{
    if(z<=e[ro].l&&e[ro].r<=y)
    {
        return e[ro].sum;
    }
    dowm(ro);
    ll ans=0;
    int mid=(e[ro].l+e[ro].r)/2;
    if(z<=mid)   ans+=sum(2*ro,z,y);
    if(mid<y)    ans+=sum(2*ro+1,z,y);
    return ans;
}
ll qsum(int u,int v)
{
    ll summ=0;
    while(top[u]!=top[v])
    {
        if(dep[top[u]]<dep[top[v]])  std::swap(u,v);
        summ+=sum(1,id[top[u]],id[u]);
        u=fa[top[u]];   
    }   
    if(id[u]>id[v])      std::swap(u,v);
    summ+=sum(1,id[u],id[v]);
    return summ;
}
int main()
{
    int n,u,v,m;
    scanf("%d %d",&n,&m);
    for(int i=1;i<=n;i++)    scanf("%d",&vi[i]);
    for(int i=1;i<n;i++) scanf("%d %d",&u,&v),insert(u,v);
    dfs1(1);    dfs2(1,1);
    build(1,1,n);
    long long  q;
    for(int i=1;i<=n;i++)    change(1,id[i],id[i],vi[i]);
    for(int i=1;i<=m;i++)
    {
        scanf("%d %d",&u,&v);
        if(u==1)    scanf("%lld",&q),change(1,id[v],id[v],q);       
        else if(u==2)   scanf("%lld",&q),change(1,left[v],right[v],q);
        else printf("%lld\n",qsum(1,v));
    }
    return 0;
}



posted @ 2017-05-08 21:35  Brian551  阅读(121)  评论(0编辑  收藏  举报