子树的和+换根
有一个n个节点的子树节点编号(1~n),每个点有一个权值,1为树的根,支持以下三种操作。
1.查询以某个点为根的子树上所有节点的权值之和
2.修改以某个点为根的子树上所有节点的权值
3.重新指定一个节点为树的新根
输入
第一行两个整数n m表示树的节点个数n,操作次数m。
第二行n个整数表示节点的权值
第三以下n-1行每行两个整数x,y表示x到y是一条边
以下m行,每行代表一个操作
1 u 表示查询以u为根的子树的所有节点的权值之和
2 u k表示查询以u为根的子树的所有节点的权值加k
3 u表示将u 设为树的新根
#include<bits/stdc++.h> using namespace std; const int maxn=1e6+10; struct node{ int to; int nxt; }E[maxn<<1]; int h[maxn],dfn[maxn],rnk[maxn],w[maxn],siz[maxn],f[maxn][25],dep[maxn],lg[maxn],etot,tot; struct segment{ int sum; int layz; }Seg[maxn<<2]; void addedge(int x,int y) { E[++etot].to=y; E[etot].nxt=h[x]; h[x]=etot; } void dfs(int u,int fa) { dfn[u]=++tot; rnk[tot]=u; siz[u]=1; dep[u]=dep[fa]+1; f[u][0]=fa; for (int i=1;(1<<i)<=dep[u];i++) { f[u][i]=f[f[u][i-1]][i-1]; } for (int i=h[u];i;i=E[i].nxt) { int v=E[i].to; if (v!=fa) { dfs(v,u); siz[u]+=siz[v]; } } } int lca(int x,int y) { if (dep[x]<dep[y]) swap(x,y); int d=dep[x]-dep[y]; while(d) { int k=d&(-d); x=f[x][lg[k]]; d=d-k; } if (x==y) return x; for (int i=lg[dep[x]];i>=0;i++) { if (f[x][i]!=f[y][i]) { x=f[x][i]; y=f[y][i]; } } return f[x][0]; } int fid(int x,int r) { if (dep[x]<dep[r]) swap(x,r); int d=dep[x]-dep[r]; d--; while(d) { int k=d&(-d); x=f[x][lg[k]]; d=d-k; } return x; } void pushup(int rt) { Seg[rt].sum=Seg[rt<<1].sum+Seg[rt<<1|1].sum; } void pushdown(int rt,int l,int r) { if (Seg[rt].layz) { int mid=(l+r); Seg[rt<<1].sum+=(mid-l+1)*Seg[rt].layz; Seg[rt<<1].layz+=Seg[rt].layz; Seg[rt<<1|1].sum+=(r-mid)*Seg[rt].layz; Seg[rt<<1|1].layz+=Seg[rt].layz; Seg[rt].layz=0; } } void build(int rt,int l,int r) { if (l==r) { Seg[rt].sum=w[rnk[l]]; Seg[rt].layz=0; return; } int mid=(l+r)>>1; build(rt<<1,l,mid); build(rt<<1|1,mid+1,r); pushup(rt); } void update(int rt,int l,int r,int ql,int qr ,int val) { if (ql<=l&&r<=qr) { Seg[rt].sum+=(r-l+1)*val; Seg[rt].layz+=val; return ; } pushdown(rt,l,r); int mid=(l+r)>>1; if (ql<=mid) update(rt<<1,l,mid,ql,qr,val); if (qr>mid) update(rt<<1|1,mid+1,r,ql,qr,val); pushup(rt); } int query(int rt,int l,int r,int ql,int qr) { if (ql<=l&&r<=qr) { return Seg[rt].sum; } pushdown(rt,l,r); int mid=(l+r)>>1; int ans=0; if (ql<=mid) ans+=query(rt<<1,l,mid,ql,qr); if (qr>mid) ans+=query(rt<<1|1,mid+1,r,ql,qr); return ans; } int main() { int n,m,rk; cin>>n>>m; for (int i=2;i<=n;i++) lg[i]=lg[i<<1]+1; for (int i=1;i<=n;i++) cin>>w[i]; for (int i=1;i<n;i++) { int x,y; cin>>x>>y; addedge(x,y); addedge(y,x); } dfs(1,0); build(1,1,n); rk=1; for (int i=1;i<=m;i++) { int opt,u,k; cin>>opt; if (opt==1) { cin>>u; if (u==rk) { cout<<query(1,1,n,1,n)<<endl; } else if (lca(u,rk)==u) { int ans=0; int s; s=fid(u,rk); ans=query(1,1,n,1,dfn[s]-1); if (dfn[s]+siz[s]-1!=n) ans+=query(1,1,n,dfn[s]+siz[s],n); cout<<ans<<endl; } else { cout<<query(1,1,n,dfn[u],dfn[u]+siz[u]-1)<<endl; } } else if (opt==2) { cin>>u>>k; int s; if (u==rk) { update(1,1,n,1,n,k); } else if (lca(u,rk)==u) { int s; s=fid(u,rk); update(1,1,n,1,dfn[u]-1,k); if (dfn[u]+siz[u]-1!=n) update(1,1,n,dfn[u+siz[u]],n,k); } else { update(1,1,n,dfn[u],dfn[u]+siz[u]-1,k); } } else if(opt==3) { cin>>rk; } } return 0; } /* in 6 8 1 2 3 4 5 6 1 2 1 3 1 4 4 5 5 6 1 2 1 1 2 3 4 1 2 1 1 3 4 1 4 1 1 out 2 21 25 25 10 in 6 10 1 2 3 4 5 6 1 2 1 3 1 4 4 5 5 6 1 2 1 1 2 3 4 1 2 1 1 3 4 1 4 1 1 3 3 1 1 out 2 21 25 25 10 18 */