ZJOI2008 树的统计

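题目描述：给定一棵 n 个节点的树，每个节点有一个整数权值。要求支持三种操作：CHANGE u t —— 把节点 u 的权值改为 t；QMAX u v —— 询问 u 到 v 路径上（含两端点）的最大权值；QSUM u v —— 询问 u 到 v 路径上（含两端点）的权值和。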

题解:

其实就是单点修改，加上树上路径的最大值查询与权值和查询。

没有特别的难点：用树链剖分把任意一条路径拆成 O(log n) 段重链上的连续区间，再用线段树维护区间最大值与区间和即可。单点修改 O(log n)，每次路径查询 O(log² n)。

代码:

#include<cstdio>
#include<cstring>
#include<algorithm>
using namespace std;
// Capacity of every node-indexed array. It is presumably sized for n <= 30000
// plus a little slack (confirm against the problem limits).
// A typed, scoped constant instead of an object-like macro.
constexpr int N = 30050;
// 64-bit integer type used for weights, path sums and path maxima.
using ll = long long;
// Reads the next decimal integer from stdin via getchar().
// Any non-digit characters before the number are skipped; a '-' seen among
// them makes the result negative. The character that ends the digit run is
// consumed.
// Returns 0 if end of input is reached before any digit, instead of looping
// forever.
inline int rd()
{
    int f=1,c=0;
    int ch=getchar();  // int, not char, so EOF stays distinguishable from data
    while(ch<'0'||ch>'9')
    {
        if(ch==EOF)return 0;
        if(ch=='-')f=-1;
        ch=getchar();
    }
    while(ch>='0'&&ch<='9'){c=10*c+ch-'0';ch=getchar();}
    return f*c;
}
// Large sentinel; its negation seeds the running maximum in path-max queries.
constexpr ll inf = 0x3f3f3f3f3f3f3f3fll;
int n;       // number of tree nodes
int q;       // number of operations
int hed[N];  // hed[u]: index of the first edge in u's adjacency list (0 = empty)
int cnt;     // number of directed edges stored in e[] so far
ll w[N];     // initial weight of each node, indexed by node id
// One directed half of a tree edge, kept in an array-based singly linked
// adjacency list.
struct EG
{
    int to;   // node this edge leads to
    int nxt;  // index of the next edge leaving the same node (0 ends the list)
} e[2*N];     // main inserts every undirected edge in both directions

// Prepends the directed edge f -> t to f's adjacency list.
// Edge indices start at 1 so that 0 can terminate a list.
void ae(int f,int t)
{
    const int id = ++cnt;
    e[id].nxt = hed[f];
    e[id].to = t;
    hed[f] = id;
}
int fa[N];   // parent of each node (dfs1(1,0) makes 0 the root's parent)
int siz[N];  // number of nodes in each subtree
int top[N];  // head of the heavy chain containing each node (set by dfs2)
int son[N];  // heavy child: first child with the largest subtree (0 if none)
int dep[N];  // depth; a node is one deeper than its parent (dep[0] is 0)

// First heavy-light pass over the subtree of u, whose parent is f.
// Records fa/dep/siz for every node and chooses son[u]. Ties keep the child
// met first in adjacency-list order, because the comparison is strict.
void dfs1(int u,int f)
{
    fa[u] = f;
    dep[u] = dep[f] + 1;
    siz[u] = 1;
    for (int eid = hed[u]; eid != 0; eid = e[eid].nxt)
    {
        const int v = e[eid].to;
        if (v != f)
        {
            dfs1(v, u);
            siz[u] += siz[v];
            if (siz[son[u]] < siz[v]) son[u] = v;
        }
    }
}
int tin[N];  // tin[u]: 1-based position of u in the heavy-first DFS order
int pla[N];  // pla[p]: node placed at DFS position p (inverse of tin)
int tim;     // last DFS position handed out

// Second heavy-light pass: u belongs to the chain whose head is tp.
// The heavy child is visited first so that each chain occupies a contiguous
// range of DFS positions; every light child starts a new chain headed by
// itself.
void dfs2(int u,int tp)
{
    top[u] = tp;
    tin[u] = ++tim;
    pla[tim] = u;
    if (!son[u]) return;  // dfs1 found no child for u
    dfs2(son[u], tp);
    for (int eid = hed[u]; eid; eid = e[eid].nxt)
    {
        const int v = e[eid].to;
        if (v != fa[u] && v != son[u]) dfs2(v, v);
    }
}
// Lowest common ancestor of x and y using the heavy-light chains.
// Repeatedly lift whichever endpoint has the deeper chain head to just above
// that head. Once both endpoints share a chain, the shallower one is the LCA.
int get_lca(int x,int y)
{
    for (; top[x] != top[y]; x = fa[top[x]])
    {
        if (dep[top[x]] < dep[top[y]]) swap(x, y);
    }
    if (dep[y] <= dep[x]) return y;
    return x;
}
// Segment tree over the heavy-first DFS positions 1..n.
// Leaf p holds the weight of node pla[p]. Each tree node stores the maximum
// (vx) and the sum (vs) of its segment; the root is tree node 1 and the
// children of u are 2u and 2u+1.
// q1/q2 answer vertical-path queries by walking heavy chains through the
// global top/fa/dep/tin arrays.
struct segtree
{
    ll vx[N<<2];  // segment maximum for each tree node
    ll vs[N<<2];  // segment sum for each tree node

    // Recomputes tree node u from its two children.
    void update(int u)
    {
        const int lc = u<<1, rc = u<<1|1;
        vs[u] = vs[lc] + vs[rc];
        vx[u] = max(vx[lc], vx[rc]);
    }

    // Builds the subtree for positions [l, r], rooted at tree node u, from w[].
    void build(int l,int r,int u)
    {
        if (l != r)
        {
            const int mid = (l+r)>>1;
            build(l, mid, u<<1);
            build(mid+1, r, u<<1|1);
            update(u);
        }
        else
        {
            vx[u] = vs[u] = w[pla[l]];
        }
    }

    // Point assignment: sets position qx (inside [l, r], tree node u) to d.
    void insert(int l,int r,int u,int qx,ll d)
    {
        if (l == r)
        {
            vx[u] = vs[u] = d;
            return;
        }
        const int mid = (l+r)>>1;
        if (qx > mid) insert(mid+1, r, u<<1|1, qx, d);
        else          insert(l, mid, u<<1, qx, d);
        update(u);
    }

    // Maximum over positions [ql, qr], which must lie inside [l, r] (tree node u).
    ll query1(int l,int r,int u,int ql,int qr)//max
    {
        if (ql == l && qr == r) return vx[u];
        const int mid = (l+r)>>1;
        const int lc = u<<1, rc = u<<1|1;
        if (qr <= mid) return query1(l, mid, lc, ql, qr);
        if (ql > mid)  return query1(mid+1, r, rc, ql, qr);
        const ll fromLeft = query1(l, mid, lc, ql, mid);
        const ll fromRight = query1(mid+1, r, rc, mid+1, qr);
        return max(fromLeft, fromRight);
    }

    // Sum over positions [ql, qr], which must lie inside [l, r] (tree node u).
    ll query2(int l,int r,int u,int ql,int qr)//sum
    {
        if (ql == l && qr == r) return vs[u];
        const int mid = (l+r)>>1;
        const int lc = u<<1, rc = u<<1|1;
        if (qr <= mid) return query2(l, mid, lc, ql, qr);
        if (ql > mid)  return query2(mid+1, r, rc, ql, qr);
        const ll fromLeft = query2(l, mid, lc, ql, mid);
        const ll fromRight = query2(mid+1, r, rc, mid+1, qr);
        return fromLeft + fromRight;
    }

    // Maximum weight on the vertical path from u up to lim, both ends included.
    // lim is expected to be an ancestor of u (or u itself); main passes the LCA.
    ll q1(int u,int lim)
    {
        ll best = -inf;
        for (int head = top[u]; dep[head] >= dep[lim]; head = top[u])
        {
            // The whole chain piece [head, u] lies on the path.
            best = max(best, query1(1, n, 1, tin[head], tin[u]));
            u = fa[head];
        }
        // u is now on lim's chain, or has already climbed above lim.
        if (dep[u] >= dep[lim])
            best = max(best, query1(1, n, 1, tin[lim], tin[u]));
        return best;
    }

    // Sum of weights on the vertical path from u up to lim, EXCLUDING lim.
    // lim is expected to be an ancestor of u (or u itself); main passes the LCA.
    ll q2(int u,int lim)
    {
        ll total = 0;
        for (int head = top[u]; dep[head] > dep[lim]; head = top[u])
        {
            total += query2(1, n, 1, tin[head], tin[u]);
            u = fa[head];
        }
        // Same chain as lim: positions strictly after lim, down to u.
        if (dep[u] > dep[lim])
            total += query2(1, n, 1, tin[lim]+1, tin[u]);
        return total;
    }
}tr;  // the single global segment tree instance
char ch[10];
int main()
{
    n=rd();
    for(int f,t,i=1;i<n;i++)
    {
        f=rd(),t=rd();
        ae(f,t),ae(t,f);
    }
    dfs1(1,0);
    dfs2(1,1);
    for(int i=1;i<=n;i++)w[i]=rd();
    tr.build(1,n,1);
    q=rd();
    for(int u,v,i=1;i<=q;i++)
    {
        scanf("%s",ch+1);
        if(ch[1]=='C')
        {
            u=rd(),v=rd();
            tr.insert(1,n,1,tin[u],v);
        }else if(ch[2]=='M')
        {
            u=rd(),v=rd();
            int lca = get_lca(u,v);
            ll ans = max(tr.q1(u,lca),tr.q1(v,lca));
            printf("%lld\n",ans);
        }else
        {
            u=rd(),v=rd();
            int lca = get_lca(u,v);
            ll ans = tr.query2(1,n,1,tin[lca],tin[lca]);
            ans += tr.q2(u,lca)+tr.q2(v,lca);
            printf("%lld\n",ans);
        }
    }
    return 0;
}

 

posted @ 2018-12-21 13:30  LiGuanlin  阅读(129)  评论(0)  编辑  收藏  举报