P3833 [SHOI2012]魔法树 (树链剖分模板题)
题目链接:https://www.luogu.org/problem/P3833
题目大意:有一颗含有n个节点的树,初始时每个节点的值为0,有以下两种操作:
1.Add u v d表示将点u和v之间的路径上的所有节点的值都加上d。
2.Query u表示当前果树中,以点u为根的子树中,总共有多少个果子?
解题思路:树链剖分板子,具体看代码注释
代码:
#include<bits/stdc++.h> using namespace std; typedef long long ll; const int maxn=100005; int tot,cnt,head[maxn],n,m,v[maxn]; ll tree[maxn*4],lazy[maxn*4]; int d[maxn],size[maxn],son[maxn],id[maxn],rk[maxn],fa[maxn],top[maxn]; //size[i]是i的子树的节点数,top[i]是i所在重链的顶端,son[i]是i的重儿子编号 //d[i]是节点i的真深度,id[i]是i的新编号,fa[i]是i的父亲编号 struct Edge{ int v,next; }edge[maxn<<1]; void add(int u,int v){ edge[tot].v=v; edge[tot].next=head[u]; head[u]=tot++; } void dfs1(int u,int pre){ //第一遍dfs求每个节点的深度,父亲节点,所在子树大小 d[u]=d[pre]+1; fa[u]=pre; size[u]=1; for(int i=head[u];i!=-1;i=edge[i].next){ int v=edge[i].v; if(v!=pre){ dfs1(v,u); size[u]+=size[v]; if(size[son[u]]<size[v]) son[u]=v; } } } void dfs2(int u,int tp){ //求每个节点新编号和重儿子和所在重链顶端 top[u]=tp,id[u]=++cnt,rk[cnt]=u; if(son[u]) dfs2(son[u],tp); //先跑重儿子节点保证一条重链上各个节点dfs序连续,以便可以通过线段树来维护一条重链的信息 for(int i=head[u];i!=-1;i=edge[i].next){ int v=edge[i].v; if(v!=fa[u]&&v!=son[u]) dfs2(v,v); } } void pushup(int rt){ tree[rt]=tree[rt<<1]+tree[rt<<1|1]; } void pushdown(int l,int r,int rt){ if(lazy[rt]){ tree[rt<<1]=tree[rt<<1]+lazy[rt]*l; tree[rt<<1|1]=tree[rt<<1|1]+lazy[rt]*r; lazy[rt<<1]+=lazy[rt]; lazy[rt<<1|1]+=lazy[rt]; lazy[rt]=0; } } void build(int l,int r,int rt){ lazy[rt]=0; if(l==r){ tree[rt]=v[rk[l]]; return; } int mid=l+r>>1; build(l,mid,rt<<1); build(mid+1,r,rt<<1|1); pushup(rt); } void update(int L,int R,int val,int l,int r,int rt){ if(L<=l&&R>=r){ tree[rt]+=1ll*(r-l+1)*val; lazy[rt]+=val; return; } int mid=l+r>>1; pushdown(mid-l+1,r-mid,rt); if(mid>=L) update(L,R,val,l,mid,rt<<1); if(mid<R) update(L,R,val,mid+1,r,rt<<1|1); pushup(rt); } ll query(int L,int R,int l,int r,int rt){ if(L<=l&&R>=r) return tree[rt]; int mid=l+r>>1; ll res=0; pushdown(mid-l+1,r-mid,rt); if(mid>=L) res+=query(L,R,l,mid,rt<<1); if(mid<R) res+=query(L,R,mid+1,r,rt<<1|1); return res; } void updates(int x,int y,int val){ //修改x到y最短路径的值 while(top[x]!=top[y]){ //不在同一条重链就向上跳 if(d[top[x]]<d[top[y]]) swap(x,y); update(id[top[x]],id[x],val,1,n,1); x=fa[top[x]]; //跳到该重链顶节点的父节点 } if(id[x]>id[y]) swap(x,y); //两个节点在同一条重链,但可能不是同一个节点 update(id[x],id[y],val,1,n,1); } ll ask(int x,int y){ //询问x到y最短路径的值 ll res=0; while(top[x]!=top[y]){ if(d[top[x]]<d[top[y]]) swap(x,y); res+=query(id[top[x]],id[x],1,n,1); x=fa[top[x]]; } if(id[x]>id[y]) swap(x,y); res+=query(id[x],id[y],1,n,1); return res; } int main(){ scanf("%d",&n); memset(head,-1,sizeof(head)); for(int i=1;i<n;i++){ int u,v; scanf("%d%d",&u,&v); u++; v++; add(u,v); } scanf("%d",&m); cnt=0,tot=0; dfs1(1,0),dfs2(1,1); build(1,n,1); while(m--){ char op[10]; int l,r,rt,val; scanf("%s",op); if(op[0]=='A'){ scanf("%d%d%d",&l,&r,&val); l++; r++; updates(l,r,val); } else{ scanf("%d",&rt); rt++; printf("%lld\n",query(id[rt],id[rt]+size[rt]-1,1,n,1)); //因为一个子树的大小是size[rt],起点编号是从id[rt]所以是从id[rt]到id[rt]+size[rt]-1 } } return 0; }