ZJOI2008]树的统计(树链剖分,线段树)
题目描述
一棵树上有n个节点,编号分别为1到n,每个节点都有一个权值w。
我们将以下面的形式来要求你对这棵树完成一些操作:
I. CHANGE u t : 把结点u的权值改为t
II. QMAX u v: 询问从点u到点v的路径上的节点的最大权值
III. QSUM u v: 询问从点u到点v的路径上的节点的权值和
注意:从点u到点v的路径上的节点包括u和v本身
输入输出格式
输入格式:
输入文件的第一行为一个整数n,表示节点的个数。
接下来n – 1行,每行2个整数a和b,表示节点a和节点b之间有一条边相连。
接下来一行n个整数,第i个整数wi表示节点i的权值。
接下来1行,为一个整数q,表示操作的总数。
接下来q行,每行一个操作,以“CHANGE u t”或者“QMAX u v”或者“QSUM u v”的形式给出。
输出格式:
对于每个“QMAX”或者“QSUM”的操作,每行输出一个整数表示要求输出的结果
思路:
树剖板子题
将树剖好后跑线段树查询即可
代码:
#include<iostream> #include<cstdio> #include<cstring> #include<algorithm> #define rii register int i #define rij register int j #define int long long using namespace std; int n,head[200005],size[200005],f[200005],zs[200005],bnt; int top[200005],nid[200005],nsd[200005],cnt,val[200005]; int nval[200005],q; struct ljb{ int to,nxt; }y[400005]; struct xds{ int maxn,sum; }x[1000005]; inline void add(int from,int to) { bnt++; y[bnt].to=to; y[bnt].nxt=head[from]; head[from]=bnt; } void dfs1(int wz,int fa,int sd) { f[wz]=fa; nsd[wz]=sd; size[wz]=1; int maxn=0; for(rii=head[wz];i!=0;i=y[i].nxt) { int to=y[i].to; if(to!=fa) { dfs1(to,wz,sd+1); size[wz]+=size[to]; if(size[to]>maxn) { zs[wz]=to; maxn=size[to]; } } } } void dfs2(int wz,int ntop) { cnt++; nid[wz]=cnt; nval[cnt]=val[wz]; top[wz]=ntop; if(zs[wz]==0) { return; } dfs2(zs[wz],ntop); for(rii=head[wz];i!=0;i=y[i].nxt) { int to=y[i].to; if(zs[wz]!=to&&f[wz]!=to) { dfs2(to,to); } } } void build(int l,int r,int bh) { if(l==r) { x[bh].sum=nval[l]; x[bh].maxn=nval[l]; return; } int mid=(l+r)/2; build(l,mid,bh*2); build(mid+1,r,bh*2+1); x[bh].sum=x[bh*2].sum+x[bh*2+1].sum; x[bh].maxn=max(x[bh*2].maxn,x[bh*2+1].maxn); } void change(int wz,int nl,int nr,int val,int bh) { if(nl==nr&&nl==wz) { x[bh].maxn=val; x[bh].sum=val; return; } int mid=(nl+nr)/2; if(wz<=mid) { change(wz,nl,mid,val,bh*2); } else { change(wz,mid+1,nr,val,bh*2+1); } x[bh].maxn=max(x[bh*2].maxn,x[bh*2+1].maxn); x[bh].sum=x[bh*2].sum+x[bh*2+1].sum; } int querym(int l,int r,int nl,int nr,int bh) { if(l<nl) { l=nl; } if(r>nr) { r=nr; } if(l==nl&&r==nr) { return x[bh].maxn; } int mid=(nl+nr)/2; int val=-500000; if(l<=mid) { int a1=querym(l,r,nl,mid,bh*2); val=max(a1,val); } if(r>mid) { int a2=querym(l,r,mid+1,nr,bh*2+1); val=max(val,a2); } return val; } int qmax(int from,int to) { int ans=-500000; while(top[from]!=top[to]) { if(nsd[top[from]]<nsd[top[to]]) { swap(from,to); } int res=0; res=querym(nid[top[from]],nid[from],1,n,1); ans=max(ans,res); from=f[top[from]]; } if(nsd[from]>nsd[to]) { swap(from,to); } int res=0; res=querym(nid[from],nid[to],1,n,1); ans=max(ans,res); return ans; } int querys(int l,int r,int nl,int nr,int bh) { if(l<nl) { l=nl; } if(r>nr) { r=nr; } if(l==nl&&r==nr) { return x[bh].sum; } int mid=(nl+nr)/2; int val=0; if(l<=mid) { val+=querys(l,r,nl,mid,bh*2); } if(r>mid) { val+=querys(l,r,mid+1,nr,bh*2+1); } return val; } int qsum(int from,int to) { int ans=0; while(top[from]!=top[to]) { if(nsd[top[from]]<nsd[top[to]]) { swap(from,to); } int res=0; res=querys(nid[top[from]],nid[from],1,n,1); ans+=res; from=f[top[from]]; } if(nsd[from]>nsd[to]) { swap(from,to); } int res=0; res=querys(nid[from],nid[to],1,n,1); ans+=res; return ans; } signed main() { for(rii=1;i<=400000;i++) { x[i].maxn=-500000; } scanf("%lld",&n); for(rii=1;i<=n-1;i++) { int from,to; scanf("%lld%lld",&from,&to); add(from,to); add(to,from); } dfs1(1,1,0); for(rii=1;i<=n;i++) { scanf("%lld",&val[i]); } dfs2(1,1); build(1,n,1); scanf("%lld",&q); for(rii=1;i<=q;i++) { int from,to,val; string s; char c=getchar(); while(c<'A'||c>'Z') { c=getchar(); } while(c>='A'&&c<='Z') { s+=c; c=getchar(); } if(s=="CHANGE") { scanf("%lld%lld",&from,&val); change(nid[from],1,n,val,1); } if(s=="QMAX") { scanf("%lld%lld",&from,&to); int ltt=qmax(from,to); printf("%lld\n",ltt); } if(s=="QSUM") { scanf("%lld%lld",&from,&to); int ltt=qsum(from,to); printf("%lld\n",ltt); } } }