洛谷P2590 [ZJOI2008]树的统计 (树链剖分)
题目链接:
https://www.luogu.com.cn/problem/P2590
思路:
树链剖分（重链剖分）模板题：按重儿子优先的 DFS 序编号后，树上任意路径可拆成 O(log n) 段连续区间，再用线段树维护区间和与区间最大值，支持单点修改、路径求和与路径最大值查询。
代码:
// Luogu P2590 [ZJOI2008] Tree Statistics.
// Heavy-light decomposition + segment tree over the DFS order.
// Supports: CHANGE u t (set weight of u to t), QMAX u v (max weight on the
// path u..v), QSUM u v (sum of weights on the path u..v).
#include <algorithm>
#include <cstdio>
#include <cstring>

typedef long long ll;

const int MAXN = 100005;
// Sentinel for max queries; every queried range is non-empty, so it is
// never the final answer.
const ll NEG_INF = -(1LL << 62);

int n;           // number of vertices, numbered 1..n
ll val[MAXN];    // initial weight of each vertex

// ---------------- adjacency list (each tree edge stored twice) -------------
int head[MAXN], tot;
struct Edge {
    int to, nxt;
} e[MAXN << 1];

void add(int x, int y) {
    e[tot].to = y;
    e[tot].nxt = head[x];
    head[x] = tot++;
}

void add_edge(int x, int y) {
    add(x, y);
    add(y, x);
}

// ---------------- first DFS: depth, parent, subtree size, heavy child -------
int dep[MAXN], fa[MAXN], son[MAXN], sz[MAXN];

void dfs(int u, int p) {
    dep[u] = dep[p] + 1;
    sz[u] = 1;
    fa[u] = p;
    for (int i = head[u]; i != -1; i = e[i].nxt) {
        int v = e[i].to;
        if (v == p) continue;
        dfs(v, u);
        sz[u] += sz[v];
        // son[u] starts at 0 and sz[0] == 0, so the first child always wins.
        if (sz[v] > sz[son[u]]) son[u] = v;
    }
}

// ---------------- second DFS: chain tops and DFS positions -----------------
int top[MAXN], id[MAXN], cnt;
ll a[MAXN];      // a[id[u]] = val[u]; input array for build()

void dfs2(int u, int t) {
    top[u] = t;
    id[u] = ++cnt;
    a[cnt] = val[u];
    if (!son[u]) return;  // leaf: no children to visit
    // Visit the heavy child first so every heavy chain is a contiguous range.
    dfs2(son[u], t);
    for (int i = head[u]; i != -1; i = e[i].nxt) {
        int v = e[i].to;
        if (v != fa[u] && v != son[u]) dfs2(v, v);  // light child starts a chain
    }
}

// ---------------- segment tree: range sum / range max, point assign --------
ll sum[MAXN << 2], mx[MAXN << 2];

void push_up(int node) {
    sum[node] = sum[node << 1] + sum[node << 1 | 1];
    mx[node] = std::max(mx[node << 1], mx[node << 1 | 1]);
}

void build(int node, int l, int r) {
    if (l == r) {
        sum[node] = mx[node] = a[l];
        return;
    }
    int mid = (l + r) >> 1;
    build(node << 1, l, mid);
    build(node << 1 | 1, mid + 1, r);
    push_up(node);
}

// Sets position pos to v. Only point updates occur, so no lazy tags are needed.
void update(int node, int l, int r, int pos, ll v) {
    if (l == r) {
        sum[node] = mx[node] = v;
        return;
    }
    int mid = (l + r) >> 1;
    if (pos <= mid)
        update(node << 1, l, mid, pos, v);
    else
        update(node << 1 | 1, mid + 1, r, pos, v);
    push_up(node);
}

// Sum of positions [x, y]; caller guarantees the range is non-empty.
ll query_sum(int node, int l, int r, int x, int y) {
    if (x <= l && r <= y) return sum[node];
    int mid = (l + r) >> 1;
    ll res = 0;
    if (x <= mid) res += query_sum(node << 1, l, mid, x, y);
    if (y > mid) res += query_sum(node << 1 | 1, mid + 1, r, x, y);
    return res;
}

// Maximum over positions [x, y]; caller guarantees the range is non-empty.
ll query_max(int node, int l, int r, int x, int y) {
    if (x <= l && r <= y) return mx[node];
    int mid = (l + r) >> 1;
    ll res = NEG_INF;
    if (x <= mid) res = std::max(res, query_max(node << 1, l, mid, x, y));
    if (y > mid) res = std::max(res, query_max(node << 1 | 1, mid + 1, r, x, y));
    return res;
}

// ---------------- path queries via chain jumping ---------------------------
// Repeatedly lift the endpoint whose chain top is deeper, covering
// [id[top], id[endpoint]] each time, until both lie on one chain.
ll tree_max(int x, int y) {
    ll res = NEG_INF;
    while (top[x] != top[y]) {
        if (dep[top[x]] < dep[top[y]]) std::swap(x, y);
        res = std::max(res, query_max(1, 1, n, id[top[x]], id[x]));
        x = fa[top[x]];
    }
    if (id[x] > id[y]) std::swap(x, y);
    return std::max(res, query_max(1, 1, n, id[x], id[y]));
}

ll tree_sum(int x, int y) {
    ll res = 0;
    while (top[x] != top[y]) {
        if (dep[top[x]] < dep[top[y]]) std::swap(x, y);
        res += query_sum(1, 1, n, id[top[x]], id[x]);
        x = fa[top[x]];
    }
    if (id[x] > id[y]) std::swap(x, y);
    return res + query_sum(1, 1, n, id[x], id[y]);
}

int main() {
    if (scanf("%d", &n) != 1) return 0;
    memset(head, -1, sizeof(head));
    for (int i = 1; i < n; i++) {
        int x, y;
        scanf("%d%d", &x, &y);
        add_edge(x, y);
    }
    for (int i = 1; i <= n; i++) scanf("%lld", &val[i]);

    dfs(1, 0);
    dfs2(1, 1);
    build(1, 1, n);

    int q;
    scanf("%d", &q);
    while (q--) {
        char op[16];
        int x;
        ll y;  // second vertex for QMAX/QSUM, new weight for CHANGE
        scanf("%15s%d%lld", op, &x, &y);
        if (op[0] == 'C') {
            update(1, 1, n, id[x], y);
        } else if (op[1] == 'M') {
            printf("%lld\n", tree_max(x, static_cast<int>(y)));
        } else {
            printf("%lld\n", tree_sum(x, static_cast<int>(y)));
        }
    }
    return 0;
}