[ZJOI2008]树的统计(树链剖分)
原题
Solution
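这是一道树链剖分的裸题（模板题）：先用两次 DFS 把树剖分成重链，并按 DFS 序给节点重新编号（先访问重儿子，因此每条重链在序列上是连续的）；再在这个序列上建线段树，同时维护区间和与区间最大值。修改操作就是单点修改；路径求和、路径求最大值则沿重链不断向上跳，把路径拆成若干段连续区间分别查询后合并。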
#include<stdio.h>
#include<stdlib.h>
#include<limits.h>
#define ll long long
/* Return the larger of a and b (b when the two are equal). */
ll max(ll a,ll b){
    return a>b?a:b;
}
/* Exchange the two referenced ints (safe when both refer to the same object). */
void swap(int &a,int &b){
    const int held=b;
    b=a;
    a=held;
}
/* Read the next integer from stdin.
 * Every non-digit character before the number is skipped; a '-' anywhere among the
 * skipped characters makes the result negative (fast-reader convention).
 * Returns 0 if end of input is reached before any digit is seen. */
int gi(){
    int sum=0,f=1;
    int ch=getchar();   /* int, not char: EOF must stay distinguishable from every byte */
    while(ch!=EOF && (ch>'9' || ch<'0')){   /* stop at EOF instead of spinning forever */
        if(ch=='-')f=-1;
        ch=getchar();
    }
    while(ch>='0' && ch<='9'){
        sum=sum*10+ch-'0';
        ch=getchar();
    }
    return f*sum;
}
/* long long counterpart of gi(): read the next integer from stdin.
 * Every non-digit character before the number is skipped; a '-' anywhere among the
 * skipped characters makes the result negative.
 * Returns 0 if end of input is reached before any digit is seen. */
ll gl(){
    ll sum=0,f=1;
    int ch=getchar();   /* int, not char: EOF must stay distinguishable from every byte */
    while(ch!=EOF && (ch>'9' || ch<'0')){   /* stop at EOF instead of spinning forever */
        if(ch=='-')f=-1;
        ch=getchar();
    }
    while(ch>='0' && ch<='9'){
        sum=sum*10+ch-'0';
        ch=getchar();
    }
    return f*sum;
}
/* Tree storage: per-vertex singly linked adjacency lists kept in the edge pool e[]. */
const int maxn=100010;
struct node{
    int to,nxt;     // target vertex; index of the next edge with the same source (0 ends the list)
}e[maxn*2];
int cnt;            // edges stored so far; edge indices start at 1 so 0 can mean "none"
int front[maxn];    // front[u]: index of the most recently added edge leaving u
int root;
int son[maxn],dep[maxn],fa[maxn],siz[maxn];   // written by dfs1
int id[maxn],top[maxn],num;                   // written by dfs2
ll b[maxn],w[maxn];  // w: weight per vertex; b: the same weights laid out by id[]
/* Prepend the directed edge u->v to u's adjacency list. */
void Add(int u,int v){
    ++cnt;
    e[cnt].to=v;
    e[cnt].nxt=front[u];
    front[u]=cnt;
}
/* First heavy-light pass over the subtree of u (parent f, depth d).
 * Records fa[u], dep[u] and the subtree size siz[u], and stores in son[u] the child
 * with the largest subtree (son[u] stays 0 when u has no children). */
void dfs1(int u,int f,int d){
    fa[u]=f;
    dep[u]=d;
    siz[u]=1;
    for(int edge=front[u];edge!=0;edge=e[edge].nxt){
        const int child=e[edge].to;
        if(child==f)continue;          // do not walk back to the parent
        dfs1(child,u,d+1);
        siz[u]+=siz[child];
        if(son[u]==0 || siz[child]>siz[son[u]])son[u]=child;
    }
}
/* Second heavy-light pass; expects dfs1 to have filled son[] and fa[].
 * Gives u the next DFS position id[u] (1-based, counted by num), copies its weight
 * into b[] at that position, and sets top[u] to f, the head of u's heavy chain.
 * The heavy child is visited first so that every chain gets consecutive positions. */
void dfs2(int u,int f){
    top[u]=f;
    id[u]=++num;
    b[num]=w[u];
    if(son[u]==0)return;               // no children
    dfs2(son[u],f);                    // heavy child extends the current chain
    for(int edge=front[u];edge;edge=e[edge].nxt){
        const int child=e[edge].to;
        if(child==fa[u] || child==son[u])continue;
        dfs2(child,child);             // each light child starts a new chain
    }
}
struct tree{
ll max,val;
}t[4*maxn];
#define ls o<<1
#define rs o<<1|1
void pushup(int o){
t[o].val=t[ls].val+t[rs].val;
t[o].max=max(t[ls].max,t[rs].max);
}
/* Build node o of the segment tree over b[l..r]. */
void build(int o,int l,int r){
    if(l!=r){
        const int mid=(l+r)>>1;
        build(ls,l,mid);
        build(rs,mid+1,r);
        pushup(o);
        return;
    }
    t[o].max=b[l];
    t[o].val=b[l];
}
/* Point assignment: set position pos (inside node o's range [l,r]) to k,
 * then refresh sums and maxima on the way back up. */
void update(int o,int l,int r,int pos,ll k){
    if(l!=r){
        const int mid=(l+r)>>1;
        if(pos>mid)update(rs,mid+1,r,pos,k);
        else update(ls,l,mid,pos,k);
        pushup(o);
        return;
    }
    t[o].max=k;
    t[o].val=k;
}
/* Sum of positions [posl,posr] using node o, which covers [l,r].
 * Expects [posl,posr] to be a non-empty sub-range of [l,r]. */
ll query1(int o,int l,int r,int posl,int posr){
    if(l>=posl && r<=posr)return t[o].val;   // node fully inside the query
    const int mid=(l+r)>>1;
    if(posr<=mid)return query1(ls,l,mid,posl,posr);
    if(posl>mid)return query1(rs,mid+1,r,posl,posr);
    // query straddles mid: split it at the boundary
    const ll leftPart=query1(ls,l,mid,posl,mid);
    const ll rightPart=query1(rs,mid+1,r,mid+1,posr);
    return leftPart+rightPart;
}
/* Maximum over positions [posl,posr] using node o, which covers [l,r].
 * Expects [posl,posr] to be a non-empty sub-range of [l,r]. */
ll query2(int o,int l,int r,int posl,int posr){
    if(l>=posl && r<=posr)return t[o].max;   // node fully inside the query
    const int mid=(l+r)>>1;
    if(posr<=mid)return query2(ls,l,mid,posl,posr);
    if(posl>mid)return query2(rs,mid+1,r,posl,posr);
    // query straddles mid: best of both halves
    const ll leftBest=query2(ls,l,mid,posl,mid);
    const ll rightBest=query2(rs,mid+1,r,mid+1,posr);
    return max(leftBest,rightBest);
}
/* Sum of node weights on the tree path between x and y, both endpoints included.
 * Lifts whichever endpoint has the deeper chain head, adding that chain segment,
 * until both lie on one chain, then adds the segment between them. */
ll sum(int x,int y){
    ll total=0;
    for(;top[x]!=top[y];x=fa[top[x]]){
        if(dep[top[x]]<dep[top[y]])swap(x,y);   // make x the endpoint with the deeper chain head
        total+=query1(1,1,num,id[top[x]],id[x]);
    }
    // same chain now: the shallower endpoint has the smaller DFS position
    const int upper=dep[x]>dep[y]?y:x;
    const int lower=dep[x]>dep[y]?x:y;
    total+=query1(1,1,num,id[upper],id[lower]);
    return total;
}
/* Maximum node weight on the tree path between x and y, both endpoints included.
 * Lifts whichever endpoint has the deeper chain head, folding that chain segment's
 * maximum into the answer, until both lie on one chain, then folds in the segment
 * between them. The path always contains at least one node, so the LLONG_MIN start
 * value is always replaced by a real weight. */
ll big(int x,int y){
    // Start below every representable weight; a fixed bound such as -30000 would
    // return a wrong answer for paths whose weights are all smaller than it.
    ll ans=LLONG_MIN;
    while(top[x]!=top[y]){
        if(dep[top[x]]<dep[top[y]])swap(x,y);   // make x the endpoint with the deeper chain head
        ans=max(ans,query2(1,1,num,id[top[x]],id[x]));
        x=fa[top[x]];
    }
    if(dep[x]>dep[y])swap(x,y);   // x is now the shallower node, i.e. smaller DFS position
    ans=max(ans,query2(1,1,num,id[x],id[y]));
    return ans;
}
int main(){
int i,j,k,n,m;
n=gi();
for(i=1;i<n;i++){
int u=gi(),v=gi();
Add(u,v);Add(v,u);
}
for(i=1;i<=n;i++)w[i]=gl();
root=1;
dfs1(root,0,1);#include<stdio.h>
#include<stdlib.h>
#define ll long long
ll max(ll a,ll b){
if(a>b)return a;
return b;
}
void swap(int &a,int &b){
int tmp=a;a=b;b=tmp;
}
int gi(){
int sum=0,f=1;char ch=getchar();
while(ch>'9' || ch<'0'){if(ch=='-')f=-1;ch=getchar();}
while(ch>='0' && ch<='9'){sum=sum*10+ch-'0';ch=getchar();}
return f*sum;
}
ll gl(){
ll sum=0,f=1;char ch=getchar();
while(ch>'9' || ch<'0'){if(ch=='-')f=-1;ch=getchar();}
while(ch>='0' && ch<='9'){sum=sum*10+ch-'0';ch=getchar();}
return f*sum;
}
const int maxn=100010;
struct node{
int to,nxt;
}e[maxn*2];
int cnt,front[maxn],root,son[maxn],dep[maxn],fa[maxn],siz[maxn],id[maxn],top[maxn],num;
ll b[maxn],w[maxn];
void Add(int u,int v){
e[++cnt].to=v;e[cnt].nxt=front[u];front[u]=cnt;
}
void dfs1(int u,int f,int d){
fa[u]=f;dep[u]=d;siz[u]=1;
for(int i=front[u];i;i=e[i].nxt){
int v=e[i].to;
if(v!=f){
dfs1(v,u,d+1);
siz[u]+=siz[v];
if(!son[u] || siz[son[u]]<siz[v])son[u]=v;
}
}
}
void dfs2(int u,int f){
top[u]=f;id[u]=++num;b[num]=w[u];
if(!son[u])return;
dfs2(son[u],f);
for(int i=front[u];i;i=e[i].nxt){
int v=e[i].to;
if(v!=fa[u] && v!=son[u])dfs2(v,v);
}
}
struct tree{
ll max,val;
}t[4*maxn];
#define ls o<<1
#define rs o<<1|1
void pushup(int o){
t[o].val=t[ls].val+t[rs].val;
t[o].max=max(t[ls].max,t[rs].max);
}
void build(int o,int l,int r){
if(l==r){
t[o].val=t[o].max=b[l];return;
}
int mid=(l+r)>>1;
build(ls,l,mid);build(rs,mid+1,r);
pushup(o);
}
void update(int o,int l,int r,int pos,ll k){
if(l==r){
t[o].val=k;t[o].max=k;return;
}
int mid=(l+r)>>1;
if(pos<=mid)update(ls,l,mid,pos,k);
else update(rs,mid+1,r,pos,k);
pushup(o);
}
ll query1(int o,int l,int r,int posl,int posr){
if(posl<=l && r<=posr)return t[o].val;
int mid=(l+r)>>1;
if(mid>=posr)return query1(ls,l,mid,posl,posr);
if(mid<posl)return query1(rs,mid+1,r,posl,posr);
return query1(ls,l,mid,posl,mid)+query1(rs,mid+1,r,mid+1,posr);
}
ll query2(int o,int l,int r,int posl,int posr){
if(posl<=l && r<=posr)return t[o].max;
int mid=(l+r)>>1;
if(mid>=posr)return query2(ls,l,mid,posl,posr);
if(mid<posl)return query2(rs,mid+1,r,posl,posr);
return max(query2(ls,l,mid,posl,mid),query2(rs,mid+1,r,mid+1,posr));
}
ll sum(int x,int y){
ll ans=0;
while(top[x]!=top[y]){
if(dep[top[x]]<dep[top[y]])swap(x,y);
ans+=query1(1,1,num,id[top[x]],id[x]);
x=fa[top[x]];
}
if(dep[x]>dep[y])swap(x,y);
ans+=query1(1,1,num,id[x],id[y]);
return ans;
}
ll big(int x,int y){
ll ans=-30000;
int s=0;
while(top[x]!=top[y]){
if(dep[top[x]]<dep[top[y]])swap(x,y);
ans=max(ans,query2(1,1,num,id[top[x]],id[x]));
x=fa[top[x]];
}
if(dep[x]>dep[y])swap(x,y);
ans=max(ans,query2(1,1,num,id[x],id[y]));
return ans;
}
int main(){
int i,j,k,n,m;
n=gi();
for(i=1;i<n;i++){
int u=gi(),v=gi();
Add(u,v);Add(v,u);
}
for(i=1;i<=n;i++)w[i]=gl();
root=1;
dfs1(root,0,1);
dfs2(root,root);
build(1,1,n);
scanf("%d",&m);
for(i=1;i<=m;i++){
char op[10];scanf("%s",op);
if(op[0]=='C'){
int u;ll t;scanf("%d%lld",&u,&t);
update(1,1,n,id[u],t);
}
else if(op[1]=='S'){
int u,v;scanf("%d%d",&u,&v);
printf("%lld\n",sum(u,v));
}
else{
int u,v;scanf("%d%d",&u,&v);
printf("%lld\n",big(u,v));
}
}
return 0;
}
dfs2(root,root);
build(1,1,n);
scanf("%d",&m);
for(i=1;i<=m;i++){
char op[10];scanf("%s",op);
if(op[0]=='C'){
int u;ll t;scanf("%d%lld",&u,&t);
update(1,1,n,id[u],t);
}
else if(op[1]=='S'){
int u,v;scanf("%d%d",&u,&v);
printf("%lld\n",sum(u,v));
}
else{
int u,v;scanf("%d%d",&u,&v);
printf("%lld\n",big(u,v));
}
}
return 0;
}