树链剖分
给出一棵 n 个节点的树,初始每个节点有一个点权,要求维护三种操作:
1 u w:将顶点 u 的权值修改为 w。
2 u v:询问从 u 到 v 的路径上所有顶点的权值和。
3 u v:询问从 u 到 v 的路径上最大的权值是多少。
代码:
#include<iostream> #include<cstdio> #include<cstdlib> #include<cstring> #include<algorithm> #define ll long long #define il inline #define db double #define max(a,b) ((a)>(b)?(a):(b)) #define min(a,b) ((a)<(b)?(a):(b)) using namespace std; il int gi() { int x=0,y=1; char ch=getchar(); while(ch<'0'||ch>'9') { if(ch=='-') y=-1; ch=getchar(); } while(ch>='0'&&ch<='9') { x=x*10+ch-'0'; ch=getchar(); } return x*y; } il ll gl() { ll x=0,y=1; char ch=getchar(); while(ch<'0'||ch>'9') { if(ch=='-') y=-1; ch=getchar(); } while(ch>='0'&&ch<='9') { x=x*10+ch-'0'; ch=getchar(); } return x*y; } ll point[100045];//权值 int size[100045];//子树大小 int top[100045];//该链顶 int fa[100045];//爸爸 int son[100045];//重儿子 int deep[100045];//节点深度 int num[100045],tot;//编号 int pos[100045];//编号对应的点 int head[200045],cnt; struct edge { int next,to; }e[200045]; il void add(int from,int to) { e[++cnt].next=head[from]; e[cnt].to=to; head[from]=cnt; } //bool vis[100045]; void dfs1(int x) { int r=head[x]; size[x]=1; while(r!=-1) { int now=e[r].to; if(now!=fa[x]) { deep[now]=deep[x]+1; fa[now]=x; dfs1(now); size[x]+=size[now]; if(son[x]==-1||size[now]>size[son[x]]) son[x]=now; } r=e[r].next; } } void dfs2(int x,int anc) { top[x]=anc; num[x]=++tot; pos[tot]=x; if(son[x]==-1) return; dfs2(son[x],anc);//找重链 int r=head[x]; while(r!=-1) { int now=e[r].to; if(now!=fa[x]&&now!=son[x]) dfs2(now,now);//轻链 r=e[r].next; } } struct node { ll sum,maxn; }c[1000045]; void build(int rt,int l,int r) { if(l==r) { c[rt].sum=point[pos[l]]; c[rt].maxn=point[pos[l]]; return; } if(l>r) return; int mid=(l+r)>>1; build(rt<<1,l,mid); build(rt<<1|1,mid+1,r); c[rt].sum=c[rt<<1].sum+c[rt<<1|1].sum; c[rt].maxn=max(c[rt<<1].maxn,c[rt<<1|1].maxn); } void update(int rt,int l,int r,int pos,ll NUM) { if(l==r) { c[rt].sum=NUM; c[rt].maxn=NUM; return; } if(l>r)return; int mid=(l+r)>>1; if(pos<=mid) update(rt<<1,l,mid,pos,NUM); else update(rt<<1|1,mid+1,r,pos,NUM); c[rt].sum=c[rt<<1].sum+c[rt<<1|1].sum; c[rt].maxn=max(c[rt<<1].maxn,c[rt<<1|1].maxn); } ll query(int rt,int l,int r,int L,int R) { //if(l>r) //return 0; if(L<=l&&R>=r) return c[rt].sum; if(L>r||R<l)return 0; int mid=(l+r)/2; ll sum=0; if(L<=mid) sum+=query(rt<<1,l,mid,L,R); if(R>mid) sum+=query(rt<<1|1,mid+1,r,L,R); return sum; } ll queryy(int rt,int l,int r,int L,int R) { if(L<=l&&R>=r) return c[rt].maxn; int mid=(l+r)>>1; if(L>r||R<l)return -2e9; ll r1=-2e9,r2=-2e9; if(L<=mid) r1=queryy(rt<<1,l,mid,L,R); if(R>mid) r2=queryy(rt<<1|1,mid+1,r,L,R); return max(r1,r2); } int main() { freopen("tree.in","r",stdin); freopen("tree.out","w",stdout); memset(head,-1,sizeof(head)); memset(son,-1,sizeof(son)); int n=gi(); for(int i=1;i<=n;i++) point[i]=gl(); int x,y; for(int i=1;i<n;i++) { x=gi(),y=gi(); add(x,y); add(y,x); } deep[1]=1; fa[1]=1; dfs1(1); dfs2(1,1); build(1,1,n); int m=gi(); int p; for(int i=1;i<=m;i++) { //printf("c[1].sum=%d\n",c[1].sum); p=gi(); if(p==1) { x=gi(),y=gi(); update(1,1,n,num[x],y);//点更新 } if(p==2) { x=gi(),y=gi(); ll sum=0; while(top[x]!=top[y]) { if(deep[top[x]]<deep[top[y]])//需要x链顶更深 swap(x,y); sum+=query(1,1,n,num[top[x]],num[x]);//加上这一段区间和 x=fa[top[x]];//x跳到链顶的爸爸上 } if(num[x]<num[y]) swap(x,y); sum+=query(1,1,n,num[y],num[x]);//在加上最后一条边 printf("%lld\n",sum); } if(p==3) { x=gi(),y=gi(); ll ans=-2e9; while(top[x]!=top[y]) { if(deep[top[x]]<deep[top[y]]) swap(x,y); ans=max(ans,queryy(1,1,n,num[top[x]],num[x])); x=fa[top[x]]; } if(num[x]<num[y]) swap(x,y); ans=max(ans,queryy(1,1,n,num[y],num[x])); printf("%lld\n",ans); } } return 0; }
PEACE