树链剖分就是把树拆成一系列链,然后用数据结构对链进行维护。
树链剖分主要变量:
dep[x]表示x节点的深度,size[x]表示以x为根节点的树的大小,son[x]表示x的重儿子(重儿子即x的所有儿子中size最大的儿子),
fa[x]表示x的父亲,top[x]表示x所属重链的头部。
首先,dep,size,son,fa可以简单用一个dfs解决
void dfs(int x) { siz[x]=1;son[x]=0;siz[0]=0; for(int j=last[x];j;j=e[j].next) { int y=e[j].to; if(y!=fa[x]) { fa[y]=x; dep[y]=dep[x]+1; dfs(y); if(siz[y]>siz[son[x]])son[x]=y; siz[x]+=siz[y]; } } }
对于top,如果x为fa[x]的重儿子,那么top[x]=top[fa[x]],否则top[x]=x
void dfs_tree(int x,int tp) { w[x]=++z;top[x]=tp;//w[x]为x节点对应的线段树中的叶节点 if(son[x]!=0)dfs_tree(son[x],tp);else return; for(int j=last[x];j;j=e[j].next) { int y=e[j].to; if(y!=son[x]&&y!=fa[x])dfs_tree(y,y); } }
然后我们可以借助一些数据结构维护这些链,一般用线段树
显然一条重链的点,它们的w会构成一段区间[l,r]
所以,直接添加元素
for(int i=1;i<=n;i++)change(1,w[i],a[i]);//change为普通线段树更改操作
接下来,求值操作,求x到y的树上路径中的最大值
int solvemx(int x,int y) { int mx=-1e9; while(top[x]!=top[y])//让它们不停地沿着重链向上爬。 { if(dep[top[x]]<dep[top[y]])swap(x,y); mx=max(mx,querymx(1,w[top[x]],w[x]));//查找x所属重链的max x=fa[top[x]]; } if(w[x]>w[y])swap(x,y); mx=max(mx,querymx(1,w[x],w[y])); return mx; }
#include<bits/stdc++.h> #define maxn 300005 using namespace std; int siz[maxn],dep[maxn],top[maxn],fa[maxn],son[maxn],a[maxn]; int w[maxn],n,m,x,y,last[maxn],cnt,z; struct edge{ int to,next; }e[maxn]; struct tree{ int sum,mx,l,r; }tr[maxn]; void insert(int x,int y){ e[++cnt].to=y;e[cnt].next=last[x];last[x]=cnt; } void dfs(int x) { siz[x]=1;son[x]=0;siz[0]=0; for(int j=last[x];j;j=e[j].next) { int y=e[j].to; if(y!=fa[x]) { fa[y]=x; dep[y]=dep[x]+1; dfs(y); if(siz[y]>siz[son[x]])son[x]=y; siz[x]+=siz[y]; } } } void dfs_tree(int x,int tp) { w[x]=++z;top[x]=tp; if(son[x]!=0)dfs_tree(son[x],tp);else return; for(int j=last[x];j;j=e[j].next) { int y=e[j].to; if(y!=son[x]&&y!=fa[x])dfs_tree(y,y); } } void build(int x,int l,int r) { tr[x].l=l;tr[x].r=r; if(l==r)return; int mid=(l+r)>>1; build(x*2,l,mid); build(x*2+1,mid+1,r); } void change(int now,int x,int y) { int l=tr[now].l,r=tr[now].r,mid=(l+r)>>1; if(l==r){ tr[now].mx=tr[now].sum=y; return; } if(x<=mid)change(now*2,x,y);else change(now*2+1,x,y); tr[now].mx=max(tr[now*2].mx,tr[now*2+1].mx); tr[now].sum=tr[now*2].sum+tr[now*2+1].sum; } int querymx(int now,int x,int y) { int l=tr[now].l,r=tr[now].r,mid=(l+r)>>1; if(x<=l&&y>=r)return tr[now].mx; if(x>r||y<l)return -1e9; return max(querymx(now*2,x,y),querymx(now*2+1,x,y)); } int querysum(int now,int x,int y) { int l=tr[now].l,r=tr[now].r,mid=(l+r)>>1; if(x<=l&&y>=r)return tr[now].sum; if(x>r||y<l)return 0; return querysum(now*2,x,y)+querysum(now*2+1,x,y); } int solvemx(int x,int y) { int mx=-1e9; while(top[x]!=top[y]) { if(dep[top[x]]<dep[top[y]])swap(x,y); mx=max(mx,querymx(1,w[top[x]],w[x])); x=fa[top[x]]; } if(w[x]>w[y])swap(x,y); mx=max(mx,querymx(1,w[x],w[y])); return mx; } int solvesum(int x,int y) { int sum=0,bo=0; while(top[x]!=top[y]) { if(dep[top[x]]<dep[top[y]])swap(x,y); sum+=querysum(1,w[top[x]],w[x]); if(top[x]==1)bo=1; x=fa[top[x]]; } if(w[x]>w[y])swap(x,y); if(!bo)sum+=querysum(1,w[x],w[y]);//注意这一判断 return sum; } void solve() { char c[10];int x,y; scanf("%d",&m); for(int i=1;i<=m;i++) { scanf("%s%d%d",&c,&x,&y); if(c[0]=='C')a[x]=y,change(1,w[x],y); else { if(c[1]=='M')printf("%d\n",solvemx(x,y)); else printf("%d\n",solvesum(x,y)); } } } int main(){ scanf("%d",&n); for(int i=1;i<n;i++) { int x,y; scanf("%d%d",&x,&y); insert(x,y);insert(y,x); } for(int i=1;i<=n;i++)scanf("%d",&a[i]); fa[1]=1; dfs(1); dfs_tree(1,1); build(1,1,n); for(int i=1;i<=n;i++)change(1,w[i],a[i]); // for(int i=1;i<=n;i++)printf("%d %d %d\n",w[i],top[i],fa[i]); solve(); //printf("%d %d\n",querysum(1,1,3),querysum(1,4,4)); return 0; }
ac代码