bzoj4034: [HAOI2015]树上操作(树链剖分)
4034: [HAOI2015]树上操作
Time Limit: 10 Sec Memory Limit: 256 MBSubmit: 4423 Solved: 1411
[Submit][Status][Discuss]
Description
有一棵点数为 N 的树,以点 1 为根,且树点有边权。然后有 M 个
操作,分为三种:
操作 1 :把某个节点 x 的点权增加 a 。
操作 2 :把某个节点 x 为根的子树中所有点的点权都增加 a 。
操作 3 :询问某个节点 x 到根的路径中所有点的点权和。
Input
第一行包含两个整数 N, M 。表示点数和操作数。接下来一行 N 个整数,表示树中节点的初始权值。接下来 N-1
行每行三个正整数 fr, to , 表示该树中存在一条边 (fr, to) 。再接下来 M 行,每行分别表示一次操作。其中
第一个数表示该操作的种类( 1-3 ) ,之后接这个操作的参数( x 或者 x a ) 。
Output
对于每个询问操作,输出该询问的答案。答案之间用换行隔开。
Sample Input
5 5
1 2 3 4 5
1 2
1 4
2 3
2 5
3 3
1 2 1
3 5
2 1 2
3 3
1 2 3 4 5
1 2
1 4
2 3
2 5
3 3
1 2 1
3 5
2 1 2
3 3
Sample Output
6
9
13
9
13
HINT
对于 100% 的数据, N,M<=100000 ,且所有输入数据的绝对值都不会超过 10^6 。
啊啊啊啊啊啊啊
一个long long 没转好
调了一个多小时
难怪我小数据对拍拍不出来
以后涉及long long的一定要仔细处理!!!
对于树链剖分
有一个性质:
从根节点向下对所有结点的重链进行编号,的对于一条重链来说,其编号一定是连续的,对于一个结点来说,其子树的编号一定是连续的。
所以第二次dfs时处理出以该节点为根的子树的区间左端点和右端点
就是线段数的区间修改了
#include<cstdio>
#define ll long long
#include<iostream>
const int N=100018;
struct node
{
int l,r;
ll col,sum;
}e[N*4];
struct edgt
{
int to,next;
}f[N*2];
int vi[N],first[N],cnt=1,ek=0,id[N],dep[N],size[N],fa[N],left[N],right[N],top[N];
void insert(int u,int v)
{
f[++cnt].to=v;f[cnt].next=first[u];first[u]=cnt;
f[++cnt].to=u;f[cnt].next=first[v];first[v]=cnt;
}
void dfs1(int ro)
{
size[ro]=1;
for(int k=first[ro];k;k=f[k].next)
{
if(f[k].to==fa[ro]) continue;
dep[f[k].to]=dep[ro]+1;
fa[f[k].to]=ro;
dfs1(f[k].to);
size[ro]+=size[f[k].to];
}
}
void dfs2(int ro,int chain)
{
int i=0; left[ro]=id[ro]=++ek;top[ro]=chain;
for(int k=first[ro];k;k=f[k].next)
if(dep[f[k].to]>dep[ro]&&size[f[k].to]>size[i]) i=f[k].to;
if(!i)
{
right[ro]=ek;return;
}
dfs2(i,chain);
for(int k=first[ro];k;k=f[k].next)
if(dep[f[k].to]>dep[ro]&&f[k].to!=i) dfs2(f[k].to,f[k].to);
right[ro]=ek;
}
void build(int ro,int l,int r)
{
e[ro].l=l;e[ro].r=r;
if(l==r) return;
int mid=(l+r)/2;
build(2*ro,l,mid);
build(2*ro+1,mid+1,r);
}
void dowm(int ro)
{
if(e[ro].col)
{
int u=2*ro,v=u+1;
e[u].col+=e[ro].col;
e[v].col+=e[ro].col;
e[u].sum+=e[ro].col*(e[u].r-e[u].l+1);
e[v].sum+=e[ro].col*(e[v].r-e[v].l+1);
e[ro].col=0;
}
}
void change(int ro,int z,int y,ll u)
{
if(z<=e[ro].l&&e[ro].r<=y)
{
e[ro].sum+=u*(e[ro].r-e[ro].l+1);
e[ro].col+=u;
return;
}
dowm(ro);
int mid=(e[ro].l+e[ro].r)/2;
if(z<=mid) change(2*ro,z,y,u);
if(mid<y) change(2*ro+1,z,y,u);
e[ro].sum=e[2*ro].sum+e[2*ro+1].sum;
}
ll sum(int ro,int z,int y)
{
if(z<=e[ro].l&&e[ro].r<=y)
{
return e[ro].sum;
}
dowm(ro);
ll ans=0;
int mid=(e[ro].l+e[ro].r)/2;
if(z<=mid) ans+=sum(2*ro,z,y);
if(mid<y) ans+=sum(2*ro+1,z,y);
return ans;
}
ll qsum(int u,int v)
{
ll summ=0;
while(top[u]!=top[v])
{
if(dep[top[u]]<dep[top[v]]) std::swap(u,v);
summ+=sum(1,id[top[u]],id[u]);
u=fa[top[u]];
}
if(id[u]>id[v]) std::swap(u,v);
summ+=sum(1,id[u],id[v]);
return summ;
}
int main()
{
int n,u,v,m;
scanf("%d %d",&n,&m);
for(int i=1;i<=n;i++) scanf("%d",&vi[i]);
for(int i=1;i<n;i++) scanf("%d %d",&u,&v),insert(u,v);
dfs1(1); dfs2(1,1);
build(1,1,n);
long long q;
for(int i=1;i<=n;i++) change(1,id[i],id[i],vi[i]);
for(int i=1;i<=m;i++)
{
scanf("%d %d",&u,&v);
if(u==1) scanf("%lld",&q),change(1,id[v],id[v],q);
else if(u==2) scanf("%lld",&q),change(1,left[v],right[v],q);
else printf("%lld\n",qsum(1,v));
}
return 0;
}