// [Heavy-light decomposition template] BZOJ 1036: Tree Statistics (【树链剖分模板】bzoj1036 树的统计)
#include<cstdio>
#include<algorithm>
using namespace std;
// Capacity of every per-vertex array. Vertices are numbered from 1; index 0 acts
// as the "no vertex" marker (see fa[] / son[]).
// NOTE(review): presumably the problem's upper bound on n plus slack - confirm.
const int N=30000+5;
// n: number of vertices; v[x]: value of vertex x as read from the input.
int n,v[N];
// Adjacency lists stored as chained edge records (ids start at 1, 0 ends a chain):
//   num     - number of edge records created so far
//   last[x] - id of the most recently added edge leaving x
//   nxt[e]  - id of the edge that was added to the same vertex before e
//   ver[e]  - vertex that edge e leads to
// nxt/ver hold 2*N records because main() stores each edge in both directions.
int num,last[N],nxt[2*N],ver[2*N];
// Inserts the directed edge x -> y at the front of x's adjacency chain.
inline void add(int x,int y)
{
    const int e=++num;   // id of the freshly allocated edge record
    ver[e]=y;
    nxt[e]=last[x];      // link to the previous head of x's chain
    last[x]=e;
}
// Per-vertex data filled by build(int):
//   siz[x]  - number of vertices in the subtree of x
//   son[x]  - child of x with the largest subtree (0 when x has no child)
//   fa[x]   - parent of x (stays 0 for the vertex build(int) is first called on)
//   deep[x] - depth of x; main() seeds deep[1]=1 before the traversal
int siz[N],son[N],fa[N],deep[N];
// First traversal: for the subtree hanging below x, assigns each child its
// parent and depth, accumulates subtree sizes and selects the heavy child son[x].
// fa[x] and deep[x] must already be set by the caller.
void build(int x)
{
    siz[x]=1;
    son[x]=0;
    for(int e=last[x];e!=0;e=nxt[e])
    {
        const int child=ver[e];
        if(child==fa[x]) continue;   // do not walk back up to the parent
        fa[child]=x;
        deep[child]=deep[x]+1;
        build(child);
        // siz[0] is never written, so any real child beats the initial son[x]=0
        if(siz[child]>siz[son[x]]) son[x]=child;
        siz[x]+=siz[child];
    }
}
// Data produced by dfs():
//   id     - number of positions handed out so far
//   a[p]   - value of the vertex that received position p (input of the segment tree)
//   top[x] - first vertex of the chain x belongs to (chains follow son[] links)
//   ord[x] - position assigned to vertex x
int id,a[N],top[N],ord[N];
// Second traversal: gives x the next free position, copies its value into a[],
// and sets top[x]. The heavy child is visited first so that every chain
// occupies consecutive positions; the remaining children each start a new chain.
void dfs(int x)
{
    ord[x]=++id;
    a[id]=v[x];

    const int parent=fa[x];
    // x continues its parent's chain only when it is the parent's heavy child
    top[x]=(son[parent]==x)?top[parent]:x;

    const int heavy=son[x];
    if(heavy) dfs(heavy);

    for(int e=last[x];e;e=nxt[e])
    {
        const int to=ver[e];
        if(to==parent || to==heavy) continue;
        dfs(to);
    }
}
// Segment-tree node: covers positions [l, r] and keeps the sum and the maximum
// of a[l..r]. Also used as the return type of ask() and query().
struct point
{
    int l,r;
    int sum,maxx;
};
// Node storage: node 1 is the root, node i has children 2*i and 2*i+1.
point t[4*N];
void build(int i,int p,int q)
{t[i].l=p; t[i].r=q;
if(p==q) {t[i].maxx=t[i].sum=a[p]; return;}
int mid=p+q>>1;
build(2*i,p,mid);
build(2*i+1,mid+1,q);
t[i].sum=t[2*i].sum+t[2*i+1].sum;
t[i].maxx=max(t[2*i].maxx,t[2*i+1].maxx);
}
void change(int i,int p,int x)
{if(t[i].l==t[i].r && t[i].l==p) {t[i].maxx=t[i].sum=x; return;}
int mid=t[i].l+t[i].r>>1;
if(p<=mid) change(2*i,p,x);
else change(2*i+1,p,x);
t[i].sum=t[2*i].sum+t[2*i+1].sum;
t[i].maxx=max(t[2*i].maxx,t[2*i+1].maxx);
}
// Range query on the subtree of node i: returns a point whose sum and maxx
// aggregate positions [p, q]. When two halves are merged only sum and maxx of
// the result are set; its l and r are left unset.
point ask(int i,int p,int q)
{
    const point &node=t[i];
    if(p<=node.l && node.r<=q) return node;   // node lies entirely inside [p, q]

    const int mid=(node.l+node.r)>>1;
    if(q<=mid) return ask(2*i,p,q);           // query confined to the left half
    if(p>mid) return ask(2*i+1,p,q);          // query confined to the right half

    // query straddles the midpoint: combine both halves
    const point lhs=ask(2*i,p,q);
    const point rhs=ask(2*i+1,p,q);
    point re;
    re.sum=lhs.sum+rhs.sum;
    re.maxx=std::max(lhs.maxx,rhs.maxx);
    return re;
}
// Path query between vertices x and y: returns a point whose sum / maxx are the
// sum / maximum of the values on the path (l and r of the result are unset).
// While the two vertices sit on different chains, the chain whose head is deeper
// is consumed from its head down to the current vertex, then the walk jumps to
// the parent of that head.
// NOTE(review): the initial maximum -30005 assumes no vertex value is below it -
// confirm against the input limits.
inline point query(int x,int y)
{
    point re;
    re.sum=0;
    re.maxx=-30005;
    while(top[x]!=top[y])
    {
        if(deep[top[x]]<deep[top[y]]) std::swap(x,y);
        const int head=top[x];
        const point seg=ask(1,ord[head],ord[x]);
        re.sum+=seg.sum;
        if(seg.maxx>re.maxx) re.maxx=seg.maxx;
        x=fa[head];
    }
    // same chain now: y becomes the shallower end, which has the smaller position
    if(deep[x]<deep[y]) std::swap(x,y);
    const point seg=ask(1,ord[y],ord[x]);
    re.sum+=seg.sum;
    if(seg.maxx>re.maxx) re.maxx=seg.maxx;
    return re;
}
int main()
{
scanf("%d",&n); int x,y,m; char op[15];
for(int i=1;i<n;i++)
{scanf("%d%d",&x,&y);
add(x,y); add(y,x);
}
for(int i=1;i<=n;i++) scanf("%d",&v[i]);
deep[1]=1; build(1); dfs(1);
build(1,1,id);
scanf("%d",&m);
while(m--)
{scanf("%s%d%d",op,&x,&y);
if(op[1]=='H') change(1,ord[x],y);
else if(op[1]=='M') {point ans=query(x,y);
printf("%d\n",ans.maxx);}
else if(op[1]=='S') {point ans=query(x,y);
printf("%d\n",ans.sum);}
}
return 0;
}