[ZJOI2008] 树的统计Count

题目链接:戳我
树链剖分。
注意一点就是维护最大值的时候最好写成下面代码里那个样子,要不然会因为可能左右区间没有的问题有奇奇怪怪的锅。
代码如下:

#include<iostream>
#include<cstdio>
#include<cstring>
#include<algorithm>
#define MAXN 100010
using namespace std;
int n,m,t,tot;
int a[MAXN],wt[MAXN],head[MAXN],sum[MAXN],maxx[MAXN];
int son[MAXN],siz[MAXN],fa[MAXN],id[MAXN],dep[MAXN],top[MAXN];
char cur[10];
struct Edge{int nxt,to;}edge[MAXN<<1];
inline void add(int from,int to){edge[++t].nxt=head[from];edge[t].to=to;head[from]=t;}
inline void dfs1(int now,int pre)
{
    siz[now]=1;
    fa[now]=pre;
    dep[now]=dep[pre]+1;
    int maxx=-1;
    for(int i=head[now];i;i=edge[i].nxt)
    {
        int v=edge[i].to;
        if(v==pre) continue;
        dfs1(v,now);
        siz[now]+=siz[v];
        if(siz[v]>maxx) maxx=siz[v],son[now]=v;
    }
}
inline void dfs2(int now,int topf)
{
    id[now]=++tot;
    top[now]=topf;
    wt[tot]=a[now];
    if(son[now]) dfs2(son[now],topf);
    for(int i=head[now];i;i=edge[i].nxt)
    {
        int v=edge[i].to;
        if(v==fa[now]||v==son[now]) continue;
        dfs2(v,v);
    }
}
inline int ls(int x){return x<<1;}
inline int rs(int x){return x<<1|1;}
inline void push_up(int x)
{   
    sum[x]=sum[ls(x)]+sum[rs(x)];
    maxx[x]=max(maxx[ls(x)],maxx[rs(x)]);
}
inline void build(int x,int l,int r)
{
    if(l==r) {sum[x]=wt[l],maxx[x]=wt[l];return;}
    int mid=(l+r)>>1;
    build(ls(x),l,mid);
    build(rs(x),mid+1,r);
    push_up(x);
}
inline void update(int x,int l,int r,int pos,int k)
{
    if(l==r) {sum[x]=k,maxx[x]=k;return;}
    int mid=(l+r)>>1;
    if(pos<=mid) update(ls(x),l,mid,pos,k);
    else update(rs(x),mid+1,r,pos,k);
    push_up(x);
}
inline int query_sum(int x,int l,int r,int ll,int rr)
{
    if(ll<=l&&r<=rr) return sum[x];
    int mid=(l+r)>>1;
    int cur_ans=0;
    if(ll<=mid) cur_ans+=query_sum(ls(x),l,mid,ll,rr);
    if(mid<rr) cur_ans+=query_sum(rs(x),mid+1,r,ll,rr);
    return cur_ans;
}
inline int query_max(int x,int l,int r,int ll,int rr)
{
    if(ll<=l&&r<=rr) return maxx[x];
    int mid=(l+r)>>1;
    int ans=-2147483647;
    if(ll<=mid) ans=max(ans,query_max(ls(x),l,mid,ll,rr));
    if(mid<rr) ans=max(ans,query_max(rs(x),mid+1,r,ll,rr));
    return ans;
}
inline int solve_max(int x,int y)
{
    int maxx=(int)-1e9;
    while(top[x]!=top[y])
    {
        if(dep[top[x]]<dep[top[y]]) swap(x,y);
        maxx=max(maxx,query_max(1,1,n,id[top[x]],id[x]));
        x=fa[top[x]];
    }
    if(dep[x]<dep[y]) swap(x,y);
    maxx=max(maxx,query_max(1,1,n,id[y],id[x]));
    return maxx;
}
inline int solve_sum(int x,int y)
{
    int cur_ans=0;
    while(top[x]!=top[y])
    {
        if(dep[top[x]]<dep[top[y]]) swap(x,y);
        cur_ans+=query_sum(1,1,n,id[top[x]],id[x]);
        x=fa[top[x]];
    }
    if(dep[x]<dep[y]) swap(x,y);
    cur_ans+=query_sum(1,1,n,id[y],id[x]);
    return cur_ans;
}
int main()
{
    #ifndef ONLINE_JUDGE
    freopen("ce.in","r",stdin);
    #endif
    scanf("%d",&n);
    for(int i=1;i<n;i++)
    {
        int u,v;
        scanf("%d%d",&u,&v);
        add(u,v),add(v,u);
    }
    for(int i=1;i<=n;i++) scanf("%d",&a[i]);
    dfs1(1,1);
    dfs2(1,1);
    build(1,1,n);
    scanf("%d",&m);
    for(int i=1;i<=m;i++)
    {
        int x,y;
        scanf("%s%d%d",cur,&x,&y);
        if(cur[1]=='M') printf("%d\n",solve_max(x,y));
        else if(cur[1]=='S') printf("%d\n",solve_sum(x,y));
        else update(1,1,n,id[x],y);
    }
    return 0;
}
posted @ 2019-01-19 20:32  风浔凌  阅读(158)  评论(0编辑  收藏  举报