bzoj 1036: [ZJOI2008]树的统计Count——树链剖分
Description
一棵树上有n个节点,编号分别为1到n,每个节点都有一个权值w。我们将以下面的形式来要求你对这棵树完成
一些操作: I. CHANGE u t : 把结点u的权值改为t II. QMAX u v: 询问从点u到点v的路径上的节点的最大权值 I
II. QSUM u v: 询问从点u到点v的路径上的节点的权值和 注意:从点u到点v的路径上的节点包括u和v本身
Input
输入的第一行为一个整数n,表示节点的个数。接下来n – 1行,每行2个整数a和b,表示节点a和节点b之间有
一条边相连。接下来n行,每行一个整数,第i行的整数wi表示节点i的权值。接下来1行,为一个整数q,表示操作
的总数。接下来q行,每行一个操作,以“CHANGE u t”或者“QMAX u v”或者“QSUM u v”的形式给出。
对于100%的数据,保证1<=n<=30000,0<=q<=200000;中途操作中保证每个节点的权值w在-30000到30000之间。
Output
对于每个“QMAX”或者“QSUM”的操作,每行输出一个整数表示要求输出的结果。
Sample Input
4
1 2
2 3
4 1
4 2 1 3
12
QMAX 3 4
QMAX 3 3
QMAX 3 2
QMAX 2 3
QSUM 3 4
QSUM 2 1
CHANGE 1 5
QMAX 3 4
CHANGE 3 6
QMAX 3 4
QMAX 2 4
QSUM 3 4
1 2
2 3
4 1
4 2 1 3
12
QMAX 3 4
QMAX 3 3
QMAX 3 2
QMAX 2 3
QSUM 3 4
QSUM 2 1
CHANGE 1 5
QMAX 3 4
CHANGE 3 6
QMAX 3 4
QMAX 2 4
QSUM 3 4
Sample Output
4
1
2
2
10
6
5
6
5
16
1
2
2
10
6
5
6
5
16
————————————————————————————
这题就是典型的树上路径取max 路径求和 单点修改了
算法没什么好说的 写了三种写法
1——树链剖分(线段树版)
#include<cstdio> #include<cstring> #include<algorithm> using std::swap; const int M=50007; int read(){ int ans=0,f=1,c=getchar(); while(c<'0'||c>'9'){if(c=='-') f=-1; c=getchar();} while(c>='0'&&c<='9'){ans=ans*10+(c-'0'); c=getchar();} return ans*f; } int max(int x,int y){return x>y?x:y;} char ch[5]; int n,m; int first[M],cnt=1; struct node{int to,next;}e[2*M]; void ins(int a,int b){e[++cnt]=(node){b,first[a]}; first[a]=cnt;} void insert(int a,int b){ins(a,b); ins(b,a);} int fa[M],sz[M],top[M],son[M],id[M],idp=1,dep[M]; void f1(int x){ sz[x]=1; for(int i=first[x];i;i=e[i].next){ int now=e[i].to; if(now==fa[x]) continue; fa[now]=x; dep[now]=dep[x]+1; f1(now); sz[x]+=sz[now]; if(sz[now]>sz[son[x]]) son[x]=now; } } void f2(int x,int tp){ top[x]=tp; id[x]=idp++; if(son[x]) f2(son[x],tp); for(int i=first[x];i;i=e[i].next){ int now=e[i].to; if(now!=fa[x]&&now!=son[x]) f2(now,now); } } struct pos{int l,r,sum,mx;}tr[4*M]; void build(int x,int l,int r){ tr[x].l=l; tr[x].r=r; if(l==r) return ; int mid=(l+r)>>1; build(x<<1,l,mid); build(x<<1^1,mid+1,r); } void up(int x){ tr[x].sum=tr[x<<1].sum+tr[x<<1^1].sum; tr[x].mx=max(tr[x<<1].mx,tr[x<<1^1].mx); } void modify(int x,int s,int y){ if(tr[x].l==tr[x].r) return void(tr[x].sum=tr[x].mx=y); int mid=(tr[x].l+tr[x].r)>>1; if(mid>=s) modify(x<<1,s,y); else modify(x<<1^1,s,y); up(x); } int push_max(int x,int L,int R){ if(L<=tr[x].l&&tr[x].r<=R) return tr[x].mx; int mid=(tr[x].l+tr[x].r)>>1,ans=-1e8; if(L<=mid) ans=max(ans,push_max(x<<1,L,R)); if(R>mid) ans=max(ans,push_max(x<<1^1,L,R)); return ans; } int Qmax(int x,int y){ int ans=-1e8; while(top[x]!=top[y]){ if(dep[top[x]]<dep[top[y]]) swap(x,y); ans=max(ans,push_max(1,id[top[x]],id[x])); x=fa[top[x]]; } if(id[x]>id[y]) swap(x,y); ans=max(ans,push_max(1,id[x],id[y])); return ans; } int push_sum(int x,int L,int R){ if(L<=tr[x].l&&tr[x].r<=R) return tr[x].sum; int mid=(tr[x].l+tr[x].r)>>1,sum=0; if(L<=mid) sum+=push_sum(x<<1,L,R); if(R>mid) sum+=push_sum(x<<1^1,L,R); return sum; } int Qsum(int x,int y){ int sum=0; while(top[x]!=top[y]){ if(dep[top[x]]<dep[top[y]]) swap(x,y); sum+=push_sum(1,id[top[x]],id[x]); x=fa[top[x]]; } if(id[x]>id[y]) swap(x,y); sum+=push_sum(1,id[x],id[y]); return sum; } int main(){ int x,y; n=read(); for(int i=1;i<n;i++) x=read(),y=read(),insert(x,y); f1(1); f2(1,1); build(1,1,n); for(int i=1;i<=n;i++) x=read(),modify(1,id[i],x); m=read(); for(int i=1;i<=m;i++){ scanf("%s",ch); x=read(); y=read(); if(ch[0]=='C') modify(1,id[x],y); else if(ch[1]=='M') printf("%d\n",Qmax(x,y)); else printf("%d\n",Qsum(x,y)); } return 0; }
2——树链剖分(zkw线段树版)
#include<cstdio> #include<cstring> #include<algorithm> using std::swap; const int M=50007; int read(){ int ans=0,f=1,c=getchar(); while(c<'0'||c>'9'){if(c=='-') f=-1; c=getchar();} while(c>='0'&&c<='9'){ans=ans*10+(c-'0'); c=getchar();} return ans*f; } int max(int x,int y){return x>y?x:y;} char ch[5]; int n,m; int first[M],cnt=1; struct node{int to,next;}e[2*M]; void ins(int a,int b){e[++cnt]=(node){b,first[a]}; first[a]=cnt;} void insert(int a,int b){ins(a,b); ins(b,a);} int fa[M],sz[M],top[M],son[M],id[M],idp=1,dep[M]; void f1(int x){ sz[x]=1; for(int i=first[x];i;i=e[i].next){ int now=e[i].to; if(now==fa[x]) continue; fa[now]=x; dep[now]=dep[x]+1; f1(now); sz[x]+=sz[now]; if(sz[now]>sz[son[x]]) son[x]=now; } } void f2(int x,int tp){ top[x]=tp; id[x]=idp++; if(son[x]) f2(son[x],tp); for(int i=first[x];i;i=e[i].next){ int now=e[i].to; if(now!=fa[x]&&now!=son[x]) f2(now,now); } } int ly,s[3*M],mx[3*M]; void modify(int x,int w){ s[x+ly]=w; mx[x+ly]=w; for(x=(x+ly)>>1;x;x>>=1) s[x]=s[x<<1]+s[x<<1^1],mx[x]=max(mx[x<<1],mx[x<<1^1]); } int push_max(int l,int r){ int ans=-1e8; for(l=l+ly-1,r=r+ly+1;r-l!=1;l>>=1,r>>=1){ if(~l&1) ans=max(ans,mx[l^1]); if(r&1) ans=max(ans,mx[r^1]); } return ans; } int Qmax(int x,int y){ int ans=-1e8; while(top[x]!=top[y]){ if(dep[top[x]]<dep[top[y]]) swap(x,y); ans=max(ans,push_max(id[top[x]],id[x])); x=fa[top[x]]; } if(id[x]>id[y]) swap(x,y); ans=max(ans,push_max(id[x],id[y])); return ans; } int push_sum(int l,int r){ int sum=0; for(l=l+ly-1,r=r+ly+1;r-l!=1;l>>=1,r>>=1){ if(~l&1) sum+=s[l^1]; if(r&1) sum+=s[r^1]; } return sum; } int Qsum(int x,int y){ int sum=0; while(top[x]!=top[y]){ if(dep[top[x]]<dep[top[y]]) swap(x,y); sum+=push_sum(id[top[x]],id[x]); x=fa[top[x]]; } if(id[x]>id[y]) swap(x,y); sum+=push_sum(id[x],id[y]); return sum; } int main(){ int x,y; n=read(); ly=1; while(ly<=n+2) ly<<=1; for(int i=1;i<n;i++) x=read(),y=read(),insert(x,y); f1(1); f2(1,1); for(int i=1;i<=n;i++) x=read(),modify(id[i],x); m=read(); for(int i=1;i<=m;i++){ scanf("%s",ch); x=read(); y=read(); if(ch[0]=='C') modify(id[x],y); else if(ch[1]=='M') printf("%d\n",Qmax(x,y)); else printf("%d\n",Qsum(x,y)); } return 0; }
3——lct(link-cut-tree)
#include<cstdio> #include<cstring> #include<algorithm> #define LL long long using namespace std; const int M=50007; int read(){ int ans=0,f=1,c=getchar(); while(c<'0'||c>'9'){if(c=='-') f=-1; c=getchar();} while(c>='0'&&c<='9'){ans=ans*10+(c-'0'); c=getchar();} return ans*f; } int n,m,c[M][2],fa[M],a[M],b[M]; LL v[M],sum[M],mx[M]; bool rev[M]; bool isrt(int x){return c[fa[x]][0]!=x&&c[fa[x]][1]!=x;} void up(int x){ if(!x) return ; mx[x]=v[x]; sum[x]=v[x]; int l=c[x][0],r=c[x][1]; if(l) mx[x]=max(mx[x],mx[l]),sum[x]+=sum[l]; if(r) mx[x]=max(mx[x],mx[r]),sum[x]+=sum[r]; } void down(int x){ if(!rev[x]) return ; rev[x]=0; int l=c[x][0],r=c[x][1]; if(l) swap(c[l][0],c[l][1]),rev[l]^=1; if(r) swap(c[r][0],c[r][1]),rev[r]^=1; } void rotate(int x){ int y=fa[x],z=fa[y],l=0,r=1; if(c[y][1]==x) l=1,r=0; if(!isrt(y)) c[z][c[z][1]==y]=x; fa[y]=x; fa[x]=z; fa[c[x][r]]=y; c[y][l]=c[x][r]; c[x][r]=y; up(y); up(x); } int st[M],top=0,S; void splay(int x){ st[++top]=x; for(int i=x;!isrt(i);i=fa[i]) st[++top]=fa[i]; while(top) down(st[top--]); while(!isrt(x)){ int y=fa[x],z=fa[y]; if(!isrt(y)){ if(c[z][0]==y^c[y][0]==x) rotate(x); else rotate(y); } rotate(x); } } void acs(int x0){ for(int x=x0,y=0;x;splay(x),c[x][1]=y,up(x),y=x,x=fa[x]); splay(x0); } void mrt(int x){acs(x); swap(c[x][0],c[x][1]); rev[x]^=1;} void link(int x,int y){mrt(x); fa[x]=y;} void spl(int x,int y){mrt(x); acs(y);} int main() { int x,y; n=read(); for(int i=1;i<n;i++) a[i]=read(),b[i]=read(); for(int i=1;i<=n;i++) v[i]=read(); for(int i=1;i<n;i++) link(a[i],b[i]); m=read(); char ch[15]; for(int i=1;i<=m;i++){ scanf("%s",ch); x=read(); y=read(); if(ch[1]=='H') acs(x),v[x]=y; if(ch[1]=='S') spl(x,y),printf("%lld\n",sum[y]); if(ch[1]=='M') spl(x,y),printf("%lld\n",mx[y]); } return 0; }