2008ZJOI树的统计
一棵树上有n个节点,编号分别为1到n,每个节点都有一个权值w。
我们将以下面的形式来要求你对这棵树完成一些操作:
- I. CHANGE u t : 把结点u的权值改为t
- II. QMAX u v: 询问从点u到点v的路径上的节点的最大权值
- III. QSUM u v: 询问从点u到点v的路径上的节点的权值和
注意:从点u到点v的路径上的节点包括u和v本身
输入文件的第一行为一个整数n,表示节点的个数。
接下来n – 1行,每行2个整数a和b,表示节点a和节点b之间有一条边相连。
接下来n行,每行一个整数,第i行的整数wi表示节点i的权值。
接下来1行,为一个整数q,表示操作的总数。
接下来q行,每行一个操作,以“CHANGE u t”或者“QMAX u v”或者“QSUM u v”的形式给出。
对于每个“QMAX”或者“QSUM”的操作,每行输出一个整数表示要求输出的结果。
4
1 2
2 3
4 1
4 2 1 3
12
QMAX 3 4
QMAX 3 3
QMAX 3 2
QMAX 2 3
QSUM 3 4
QSUM 2 1
CHANGE 1 5
QMAX 3 4
CHANGE 3 6
QMAX 3 4
QMAX 2 4
QSUM 3 4
4
1
2
2
10
6
5
6
5
16
对于100%的数据,保证1<=n<=30000,0<=q<=200000;中途操作中保证每个节点的权值w在-30000到30000之间。
树链剖分模板题,使用线段树维护
代码中变量含义:
n,节点个数 head[N],链表;
fa[N],存储父节点 son[N],统计包括自己在内的儿子个数
id[N], 给每个节点按树链剖分的顺序重新编号 dep[N],节点深度
bl[N],该节点所在重链的顶部节点 a[N],初始节点权值
struct node{int to,next;}e[N*2]; 链表
struct tree{int l,r,sum,maxx,size;}tr[N*2]; 线段树
注:代码中采用2*n空间方式建立线段树,k的左儿子为k+1,右儿子为k+k左儿子节点数*2
dfs_cnt,建立线段树时节点编号 cnt,链表 sz,树链剖分中节点访问顺序
#include<cstdio> #include<algorithm> #define N 30001 using namespace std; int n,head[N]; int fa[N],son[N],id[N],dep[N],bl[N],a[N]; struct node{int to,next;}e[N*2]; struct tree{int l,r,sum,maxx,size;}tr[N*2]; int dfs_cnt,cnt,sz; inline void insert(int u,int v) { e[++cnt].to=v;e[cnt].next=head[u];head[u]=cnt; e[++cnt].to=u;e[cnt].next=head[v];head[v]=cnt; } void init() { scanf("%d",&n); int u,v; for(int i=1;i<n;i++) { scanf("%d%d",&u,&v); insert(u,v); } for(int i=1;i<=n;i++) scanf("%d",&a[i]); } inline void build(int l,int r) { dfs_cnt++;tr[dfs_cnt].l=l;tr[dfs_cnt].r=r;tr[dfs_cnt].size=r-l+1; if(l==r) return; int mid=l+r>>1; build(l,mid);build(mid+1,r); } inline void dfs1(int k) { son[k]++; for(int i=head[k];i;i=e[i].next) { if(e[i].to==fa[k]) continue; fa[e[i].to]=k; dep[e[i].to]=dep[k]+1; dfs1(e[i].to); son[k]+=son[e[i].to]; } } inline void dfs2(int k,int chain) { int m=0;sz++; id[k]=sz; bl[k]=chain; for(int i=head[k];i;i=e[i].next) { if(e[i].to==fa[k]) continue; if(son[e[i].to]>son[m]) m=e[i].to; } if(!m) return; dfs2(m,chain); for(int i=head[k];i;i=e[i].next) { if(e[i].to==m||e[i].to==fa[k]) continue; dfs2(e[i].to,e[i].to); } } inline void change(int k,int pos,int w) { if(tr[k].l==tr[k].r) {tr[k].sum=tr[k].maxx=w;return;} int mid=tr[k].l+tr[k].r>>1,l=k+1,r=k+tr[k+1].size*2; if(pos<=mid) change(l,pos,w); else change(r,pos,w); tr[k].sum=tr[l].sum+tr[r].sum; tr[k].maxx=max(tr[l].maxx,tr[r].maxx); } inline int querymx(int k,int opl,int opr) { if(tr[k].l==opl&&tr[k].r==opr) return tr[k].maxx; int mid=tr[k].l+tr[k].r>>1,l=k+1,r=k+tr[k+1].size*2; if(opr<=mid) return querymx(l,opl,opr); else if(opl>mid) return querymx(r,opl,opr); else return max(querymx(l,opl,mid),querymx(r,mid+1,opr)); } inline int querysum(int k,int opl,int opr) { if(tr[k].l==opl&&tr[k].r==opr) return tr[k].sum; int mid=tr[k].l+tr[k].r>>1,l=k+1,r=k+tr[k+1].size*2; if(opr<=mid) return querysum(l,opl,opr); else if(opl>mid) return querysum(r,opl,opr); else return querysum(l,opl,mid)+querysum(r,mid+1,opr); } inline void solvemx(int u,int v) { int ans=-0x7fffffff; while(bl[u]!=bl[v]) { if(dep[bl[u]]<dep[bl[v]]) swap(u,v); ans=max(ans,querymx(1,id[bl[u]],id[u])); u=fa[bl[u]]; } if(id[u]>id[v]) swap(u,v); ans=max(ans,querymx(1,id[u],id[v])); printf("%d\n",ans); } inline void solvesum(int u,int v) { int ans=0; while(bl[u]!=bl[v]) { if(dep[bl[u]]<dep[bl[v]]) swap(u,v); ans+=querysum(1,id[bl[u]],id[u]); u=fa[bl[u]]; } if(id[u]>id[v]) swap(u,v); ans+=querysum(1,id[u],id[v]); printf("%d\n",ans); } void solve() { build(1,n); for(int i=1;i<=n;i++) change(1,id[i],a[i]); int q,u,v;char c[7]; scanf("%d",&q); for(int i=1;i<=q;i++) { scanf("%s%d%d",c,&u,&v); if(c[0]=='C') change(1,id[u],v); else if (c[1]=='M') solvemx(u,v); else solvesum(u,v); } } int main() { init(); dfs1(1); dfs2(1,1); solve(); }