【ZJOI2008】树的统计
Description
一棵树上有n个节点,编号分别为1到n,每个节点都有一个权值w。
我们将以下面的形式来要求你对这棵树完成一些操作:
I. CHANGE u t : 把结点u的权值改为t
II. QMAX u v: 询问从点u到点v的路径上的节点的最大权值
III. QSUM u v: 询问从点u到点v的路径上的节点的权值和
注意:从点u到点v的路径上的节点包括u和v本身
树链剖分模板题,这里写的很好。
简单补充一下,就是把一棵树剖成一条链,然后链上做线段树即可。
Code
#include<iostream>
#include<cstdio>
#include<cstdlib>
#define fo(i,j,k) for(int i=j;i<=k;i++)
#define fd(i,j,k) for(int i=j;i>=k;i--)
#define N 90010
#define inf 1000000000
using namespace std;
int w[N];
int to[N],next[N],last[N],num=0;
int dep[N],siz[N];
int top[N],son[N];
int fa[N];
int pos[N],z[N];
struct node{
int s,mx;
}tr[N*4];
void link(int x,int y)
{
num++;
to[num]=y;
next[num]=last[x];
last[x]=num;
}
void find(int x)//寻找重边
{
siz[x]=1;son[x]=0;
for(int i=last[x];i;i=next[i])
{
int v=to[i];
if(v!=fa[x])
{
dep[v]=dep[x]+1;
fa[v]=x;
find(v);
if(siz[v]>siz[son[x]]) son[x]=v;
siz[x]+=siz[v];
}
}
}
int cnt=0;
void dfs(int x,int t)//剖链的过程
{
pos[x]=++cnt;
top[x]=t;
z[pos[x]]=x;
if(son[x]) dfs(son[x],t);
for(int i=last[x];i;i=next[i])
{
int v=to[i];
if(v!=fa[x] && v!=son[x])
dfs(v,v);
}
}
void build(int v,int l,int r)
{
if(l==r)
{
tr[v].s=tr[v].mx=w[z[l]];
return;
}
int mid=(l+r)/2;
build(v*2,l,mid);
build(v*2+1,mid+1,r);
tr[v].s=tr[v*2].s+tr[v*2+1].s;
tr[v].mx=max(tr[v*2].mx,tr[v*2+1].mx);
}
void change(int v,int l,int r,int x,int p)
{
if(l==r && l==x)
{
tr[v].s=tr[v].mx=p;
return;
}
int mid=(l+r)/2;
if(x<=mid) change(v*2,l,mid,x,p);
else change(v*2+1,mid+1,r,x,p);
tr[v].s=tr[v*2].s+tr[v*2+1].s;
tr[v].mx=max(tr[v*2].mx,tr[v*2+1].mx);
}
int qsum(int v,int l,int r,int x,int y)
{
if(l==x && r==y) return tr[v].s;
int mid=(l+r)/2;
if(y<=mid) return qsum(v*2,l,mid,x,y);
else if(x>mid) return qsum(v*2+1,mid+1,r,x,y);
else return qsum(v*2,l,mid,x,mid)+qsum(v*2+1,mid+1,r,mid+1,y);
}
int qmax(int v,int l,int r,int x,int y)
{
if(l==x && r==y) return tr[v].mx;
int mid=(l+r)/2;
if(y<=mid) return qmax(v*2,l,mid,x,y);
else if(x>mid) return qmax(v*2+1,mid+1,r,x,y);
else return max(qmax(v*2,l,mid,x,mid),qmax(v*2+1,mid+1,r,mid+1,y));
}
int findmax(int u,int v)
{
int f1=top[u],f2=top[v];
int tmp=-inf;
while(f1!=f2)
{
if(dep[f1]<dep[f2])
{
swap(f1,f2);
swap(u,v);
}
tmp=max(tmp,qmax(1,1,cnt,pos[f1],pos[u]));
u=fa[f1];
f1=top[u];
}
if(dep[u]>dep[v]) swap(u,v);
return max(tmp,qmax(1,1,cnt,pos[u],pos[v]));
}
int findsum(int u,int v)
{
int f1=top[u],f2=top[v];
int tmp=0;
while(f1!=f2)
{
if(dep[f1]<dep[f2])
{
swap(f1,f2);
swap(u,v);
}
tmp+=qsum(1,1,cnt,pos[f1],pos[u]);
u=fa[f1];
f1=top[u];
}
if(dep[u]>dep[v]) swap(u,v);
return tmp+qsum(1,1,cnt,pos[u],pos[v]);
}
int main()
{
int n;
cin>>n;
fo(i,1,n-1)
{
int x,y;
scanf("%d %d",&x,&y);
link(x,y);
link(y,x);
}
fo(i,1,n) scanf("%d",&w[i]);
find(1);
dfs(1,1);
build(1,1,cnt);
int Q;
cin>>Q;
while(Q--)
{
char s[11];
scanf("%s",s);
if(s[0]=='C')
{
int x,t;
scanf("%d %d",&x,&t);
change(1,1,cnt,pos[x],t);
}
else
{
int x,y;
scanf("%d %d",&x,&y);
if(s[1]=='M') printf("%d\n",findmax(x,y));
else printf("%d\n",findsum(x,y));
}
}
}