[ZJOI2008]树的统计(模板)
题目描述
一棵树上有n个节点,编号分别为1到n,每个节点都有一个权值w。
我们将以下面的形式来要求你对这棵树完成一些操作:
I. CHANGE u t : 把结点u的权值改为t
II. QMAX u v: 询问从点u到点v的路径上的节点的最大权值
III. QSUM u v: 询问从点u到点v的路径上的节点的权值和
注意:从点u到点v的路径上的节点包括u和v本身
输入格式
输入文件的第一行为一个整数n,表示节点的个数。
接下来n – 1行,每行2个整数a和b,表示节点a和节点b之间有一条边相连。
接下来一行n个整数,第i个整数wi表示节点i的权值。
接下来1行,为一个整数q,表示操作的总数。
接下来q行,每行一个操作,以“CHANGE u t”或者“QMAX u v”或者“QSUM u v”的形式给出。
输出格式
对于每个“QMAX”或者“QSUM”的操作,每行输出一个整数表示要求输出的结果。
就是将树剖链,挂线段树上
常规操作
不过今天在跳链时,发现程序的一个小bug;
while(top[x]!=top[y]) { if(deep[top[x]]<deep[top[y]])x^=y^=x^=y;
//就是上一行,交换x,y的条件是保证top[x]的深度要小于top[y],以免跳多了
query_sum(1,1,n,sub[top[x]],sub[x]); x=fa[top[x]]; } if(deep[x]>deep[y])x^=y^=x^=y; query_sum(1,1,n,sub[x],sub[y]);
还是挂一下代码吧
#include<bits/stdc++.h> #define re return #define inc(i,l,r) for(int i=l;i<=r;++i) using namespace std; template<typename T>inline void rd(T&x) { char c;bool f=0; while((c=getchar())<'0'||c>'9')if(c=='-')f=1; x=c^48; while((c=getchar())>='0'&&c<='9')x=x*10+(c^48); if(f)x=-x; } const int maxn=30005; int n,m,tot,k,hd[maxn],val[maxn],ans; int size[maxn],deep[maxn],fa[maxn],son[maxn]; int top[maxn],rev[maxn],sub[maxn]; struct node { int to,nt; }e[maxn<<1]; inline void add(int x,int y) { e[++k].to=y;e[k].nt=hd[x];hd[x]=k; e[++k].to=x;e[k].nt=hd[y];hd[y]=k; } inline void dfs1(int x) { size[x]=1; deep[x]=deep[fa[x]]+1; for(int i=hd[x];i;i=e[i].nt) { int v=e[i].to; if(fa[v])continue; fa[v]=x; dfs1(v); size[x]+=size[v]; if(size[v]>size[son[x]])son[x]=v; } } inline void dfs2(int x,int topf) { top[x]=topf; sub[x]=++tot; rev[tot]=x; if(son[x]) dfs2(son[x],topf); for(int i=hd[x];i;i=e[i].nt) { int v=e[i].to; if(!top[v]) dfs2(v,v); } } #define ls rt<<1 #define rs rt<<1|1 int t[maxn<<2],sum[maxn<<2]; inline void pushup(int rt) { sum[rt]=sum[ls]+sum[rs]; t[rt]=max(t[ls],t[rs]); } inline void build(int rt,int l,int r) { if(l==r) { sum[rt]=t[rt]=val[rev[l]]; re ; } int mid=(l+r)>>1; build(ls,l,mid); build(rs,mid+1,r); pushup(rt); } inline void query_sum(int rt,int l,int r,int x,int y) { if(x<=l&&r<=y) { ans+=sum[rt]; re; } int mid=(l+r)>>1; if(x<=mid)query_sum(ls,l,mid,x,y); if(y>mid)query_sum(rs,mid+1,r,x,y); } inline void query_maxx(int rt,int l,int r,int x,int y) { if(x<=l&&r<=y) { ans=max(ans,t[rt]); re ; } int mid=(l+r)>>1; if(x<=mid)query_maxx(ls,l,mid,x,y); if(y>mid) query_maxx(rs,mid+1,r,x,y); } inline void change(int rt,int l,int r,int x,int add) { if(l==r) { sum[rt]=t[rt]=add; re ; } int mid=(l+r)>>1; if(x<=mid)change(ls,l,mid,x,add); else change(rs,mid+1,r,x,add); pushup(rt); } int main() { int x,y; rd(n); inc(i,2,n) { rd(x),rd(y); add(x,y); } inc(i,1,n)rd(val[i]); fa[1]=1; dfs1(1); dfs2(1,1); build(1,1,n); char opt[10]; rd(m); inc(i,1,m) { scanf("%s",opt); rd(x),rd(y); if(opt[0]=='C') change(1,1,n,sub[x],y); else if(opt[1]=='S') { ans=0; while(top[x]!=top[y]) { if(deep[top[x]]<deep[top[y]])x^=y^=x^=y; query_sum(1,1,n,sub[top[x]],sub[x]); x=fa[top[x]]; } if(deep[x]>deep[y])x^=y^=x^=y; query_sum(1,1,n,sub[x],sub[y]); printf("%d\n",ans); } else { ans=-3100055; while(top[x]!=top[y]) { if(deep[top[x]]<deep[top[y]])x^=y^=x^=y; query_maxx(1,1,n,sub[top[x]],sub[x]); x=fa[top[x]]; } if(deep[x]>deep[y])x^=y^=x^=y; query_maxx(1,1,n,sub[x],sub[y]); printf("%d\n",ans); } } re 0; }