BZOJ4034[HAOI2015]树上操作——树链剖分+线段树
题目描述
有一棵点数为 N 的树,以点 1 为根,且树点有边权。然后有 M 个
操作,分为三种:
操作 1 :把某个节点 x 的点权增加 a 。
操作 2 :把某个节点 x 为根的子树中所有点的点权都增加 a 。
操作 3 :询问某个节点 x 到根的路径中所有点的点权和。
输入
第一行包含两个整数 N, M 。表示点数和操作数。接下来一行 N 个整数,表示树中节点的初始权值。接下来 N-1
行每行三个正整数 fr, to , 表示该树中存在一条边 (fr, to) 。再接下来 M 行,每行分别表示一次操作。其中
第一个数表示该操作的种类( 1-3 ) ,之后接这个操作的参数( x 或者 x a ) 。
输出
对于每个询问操作,输出该询问的答案。答案之间用换行隔开。
样例输入
5 5
1 2 3 4 5
1 2
1 4
2 3
2 5
3 3
1 2 1
3 5
2 1 2
3 3
1 2 3 4 5
1 2
1 4
2 3
2 5
3 3
1 2 1
3 5
2 1 2
3 3
样例输出
6
9
13
9
13
提示
对于 100% 的数据, N,M<=100000 ,且所有输入数据的绝对值都不会超过 10^6 。
算是树剖的模板题了,在线段树上架树剖序,子树修改直接修改区间,查询就往上爬,边爬边求和就好。注意要开longlong!
#include<set> #include<map> #include<queue> #include<stack> #include<cmath> #include<cstdio> #include<vector> #include<cstring> #include<iostream> #include<algorithm> #define ll long long using namespace std; int son[100010]; int size[100010]; int s[100010]; int t[100010]; int q[100010]; int v[100010]; int top[100010]; int f[100010]; int head[100010]; int to[200010]; int next[200010]; ll sum[800010]; ll a[800010]; int num; int tot; int x,y; int n,m; int opt; ll ans; void add(int x,int y) { tot++; next[tot]=head[x]; head[x]=tot; to[tot]=y; } void dfs(int x,int fa) { f[x]=fa; size[x]=1; for(int i=head[x];i;i=next[i]) { if(to[i]!=fa) { dfs(to[i],x); size[x]+=size[to[i]]; if(size[son[x]]<size[to[i]]) { son[x]=to[i]; } } } } void dfs2(int x,int tp) { s[x]=++num; q[num]=x; top[x]=tp; if(son[x]) { dfs2(son[x],tp); } for(int i=head[x];i;i=next[i]) { if(to[i]!=f[x]&&to[i]!=son[x]) { dfs2(to[i],to[i]); } } t[x]=num; } void pushup(int rt) { sum[rt]=sum[rt<<1]+sum[rt<<1|1]; } void pushdown(int rt,int l,int r) { if(a[rt]!=0) { int mid=(l+r)>>1; a[rt<<1]+=a[rt]; a[rt<<1|1]+=a[rt]; sum[rt<<1]+=a[rt]*(mid-l+1); sum[rt<<1|1]+=a[rt]*(r-mid); a[rt]=0; } } void change(int rt,int l,int r,int L,int R,int v) { if(L<=l&&r<=R) { a[rt]+=1ll*v; sum[rt]+=1ll*v*(r-l+1); return ; } pushdown(rt,l,r); int mid=(l+r)>>1; if(L<=mid) { change(rt<<1,l,mid,L,R,v); } if(R>mid) { change(rt<<1|1,mid+1,r,L,R,v); } pushup(rt); } ll query(int rt,int l,int r,int L,int R) { if(L<=l&&r<=R) { return sum[rt]; } pushdown(rt,l,r); ll res=0; int mid=(l+r)>>1; if(L<=mid) { res+=query(rt<<1,l,mid,L,R); } if(R>mid) { res+=query(rt<<1|1,mid+1,r,L,R); } return res; } int main() { scanf("%d%d",&n,&m); for(int i=1;i<=n;i++) { scanf("%d",&v[i]); } for(int i=1;i<n;i++) { scanf("%d%d",&x,&y); add(x,y); add(y,x); } dfs(1,1); dfs2(1,1); for(int i=1;i<=n;i++) { change(1,1,n,s[i],s[i],v[i]); } for(int i=1;i<=m;i++) { scanf("%d",&opt); if(opt==1) { scanf("%d%d",&x,&y); change(1,1,n,s[x],s[x],y); } else if(opt==2) { scanf("%d%d",&x,&y); change(1,1,n,s[x],t[x],y); } else { scanf("%d",&x); ans=0; while(top[x]!=1) { ans+=query(1,1,n,s[top[x]],s[x]); x=f[top[x]]; } ans+=query(1,1,n,1,s[x]); printf("%lld\n",ans); } } }