树的统计
思路
用线段树维护单点修改,区间查询最大值,区间和
注意
记得修改的点对应的下标在线段树上是\(id[u]\),因为我们的线段树是按\(dfn\)形成的
和其他的树链剖分题是一样的
注意
当查询区间最大值的时候,一定记得先把下标设成负无穷,因为我们的答案有可能是负数
代码
#include<bits/stdc++.h>
using namespace std;
const int N=200010;
int ne[N],head[N],ver[N],idx;
int id[N],cnt,nw[N];
int top[N],son[N],sz[N],dep[N],fa[N];
int n,m;
int w[N];
string s;
const int inf=0x3f3f3f3f;
void add(int u,int v)
{
ne[idx]=head[u];
ver[idx]=v;
head[u]=idx;
idx++;
}
void dfs1(int u,int father,int depth)
{
fa[u]=father;
dep[u]=depth;
sz[u]=1;
for(int i=head[u];i!=-1;i=ne[i])
{
int j=ver[i];
if(j==father)continue;
dfs1(j,u,depth+1);
sz[u]+=sz[j];
if(sz[son[u]]<sz[j]) son[u]=j;
}
}
void dfs2(int u,int t)
{
top[u]=t;
id[u]=++cnt;
nw[cnt]=w[u];
if(!son[u])return ;
dfs2(son[u],t);
for(int i=head[u];i!=-1;i=ne[i])
{
int j=ver[i];
if(j==fa[u]||j==son[u]) continue;
dfs2(j,j);
}
}
struct node{
int l,r;
long long sum,maxx;
}tr[N*4];
void pushup(int p)
{
tr[p].maxx=max(tr[p<<1].maxx,tr[p<<1|1].maxx);
tr[p].sum=tr[p<<1].sum+tr[p<<1|1].sum;
}
void build(int p,int l,int r)//建树
{
tr[p]={l,r,nw[r],nw[r]};
if(l==r)return ;
int mid=(l+r)/2;
build(p<<1,l,mid);
build(p<<1|1,mid+1,r);
pushup(p);//记得pushup
}
void update(int p,int x,int d)
{
if(tr[p].l==tr[p].r) {tr[p].maxx=d,tr[p].sum=d;return ;}
int mid=(tr[p].l+tr[p].r)/2;
if(x<=mid) update(p<<1,x,d);
if(x>mid) update(p<<1|1,x,d);
pushup(p);//记得pushup
}
long long query_max(int p,int l,int r)//区间查询最大值
{
long long ans=-inf;
if(tr[p].l>=l&&tr[p].r<=r)
{
return tr[p].maxx;
}
int mid=(tr[p].l+tr[p].r)/2;
if(l<=mid) ans=max(ans,query_max(p<<1,l,r));
if(r>mid) ans=max(ans,query_max(p<<1|1,l,r));
return ans;
}
long long query_sum(int p,int l,int r)//区间查询和
{
long long ans=0;
if(tr[p].l>=l&&tr[p].r<=r)
{
return tr[p].sum;
}
int mid=(tr[p].l+tr[p].r)/2;
if(l<=mid) ans+=query_sum(p<<1,l,r);
if(r>mid) ans+=query_sum(p<<1|1,l,r);
return ans;
}
long long tree_query_max(int u,int v)//链上查询最大值
{
long long ans=-inf;
while(top[u]!=top[v])
{
if(dep[top[u]]<dep[top[v]])
swap(u,v);
ans=max(ans,query_max(1,id[top[u]],id[u]));
u=fa[top[u]];
}
if(dep[u]<dep[v]) swap(u,v);
ans=max(ans,query_max(1,id[v],id[u]));
return ans;
}
long long tree_query_sum(int u,int v)//链上查询区间和
{
long long ans=0;
while(top[u]!=top[v])
{
if(dep[top[u]]<dep[top[v]])
swap(u,v);
ans+=query_sum(1,id[top[u]],id[u]);
u=fa[top[u]];
}
if(dep[u]<dep[v]) swap(u,v);
ans+=query_sum(1,id[v],id[u]);
return ans;
}
inline int read()
{
int x=0;
int f=1;
char ch;
ch=getchar();
while(ch>'9'||ch<'0')
{
if(ch=='-')
f=-1;
ch=getchar();
}
while(ch>='0'&&ch<='9')
{
x=x*10,x=x+ch-'0';
ch=getchar();
}
return x*f;
}
int main()
{
// freopen("1.in","r",stdin);
// freopen("res1.out","w",stdout);
memset(head,-1,sizeof(head));
n=read();
for(int i=1;i<=n-1;i++)
{
int a,b;
a=read();
b=read();
add(a,b);
add(b,a);
}
for(int i=1;i<=n;i++)
w[i]=read();
dfs1(1,-1,1);
dfs2(1,1);
build(1,1,n);
m=read();
for(int i=1;i<=m;i++)
{
int a,b;
cin>>s;
a=read();
b=read();
if(s=="QMAX")
{
cout<<tree_query_max(a,b)<<endl;
}
else if(s=="QSUM")
{
cout<<tree_query_sum(a,b)<<endl;
}
else {
update(1,id[a],b);
}
}
return 0;
}