【JZOJ4918】最近公共祖先

Description

这里写图片描述
这里第 i 节点的权值为wi

Solution

考虑染黑的节点对所有节点的贡献。

如下图,假设要染黑A节点。
这里写图片描述
那么首先所有祖先为A的节点(红色框内的)的贡献对 wA 取最大值。

这时考虑以A节点的父亲(记为B)作为lca的贡献,那么显然,非A点所在子树(绿色框内的)的贡献都可以对 WB 取最大值。

然后再考虑B节点的父亲C,那么按照上面所说的,蓝色框内的贡献都对 WC 取最大值。

然后再往上……

于是我们按照dfs序用数据结构进行区间修改,单点查询即可。

但是,树的深度不可估量,这样做虽在随机数据上表现优良,但在深度十分大的树上效率就低下了。

一个可行优化是,我们每更新要染黑祖先的点,给它打个标记,等某次更新完一个点后,我们检查是否有标记,如果有,说明该点以上的都被更新过了,就没必要再次更新。那么总共的更新次数不超过 2n

Code

#include<cstdio>
#include<cstdlib>
#include<cstring>
#include<algorithm>
#define fo(i,j,k) for(int i=j;i<=k;i++)
#define fd(i,j,k) for(int i=j;i>=k;i--)
#define rep(i,x) for(int i=ls[x];i;i=nx[i])
#define N 100010
#define M 200010
using namespace std;
int to[M],nx[M],ls[N],num=0;
int w[N],fa[N],dfn[N],pos[N],fr[N],en[N],tot=0;
bool bz[N];
bool vis[N];
struct node{
    int w,lz;
}tr[N*4];
void link(int x,int y)
{
    num++;
    to[num]=y;
    nx[num]=ls[x];
    ls[x]=num;
}
void find(int x)
{
    dfn[x]=++tot;
    pos[tot]=x;
    fr[x]=tot;
    rep(i,x)
    {
        int v=to[i];
        if(v!=fa[x])
        {
            fa[v]=x;
            find(v);
        }
    }
    en[x]=tot;
}
void put(int v)
{
    if(!tr[v].lz) return;
    tr[v*2].w=max(tr[v*2].w,tr[v].lz);
    tr[v*2+1].w=max(tr[v*2+1].w,tr[v].lz);
    tr[v*2].lz=max(tr[v*2].lz,tr[v].lz);
    tr[v*2+1].lz=max(tr[v*2+1].lz,tr[v].lz);
    tr[v].lz=0;
}
void update(int v){
    tr[v].w=max(tr[v*2].w,tr[v*2+1].w);
}
void change(int v,int l,int r,int x,int y,int t)
{
    if(l==x && r==y)
    {
        tr[v].w=max(tr[v].w,t);
        tr[v].lz=max(t,tr[v].lz);
        return;
    }
    int mid=(l+r)/2;
    put(v);
    if(y<=mid) change(v*2,l,mid,x,y,t);
    else if(x>mid) change(v*2+1,mid+1,r,x,y,t);
    else
    {
        change(v*2,l,mid,x,mid,t);
        change(v*2+1,mid+1,r,mid+1,y,t);
    }
    update(v);
}
int find(int v,int l,int r,int x)
{
    if(l==r) return tr[v].w;
    int mid=(l+r)/2;
    put(v);
    if(x<=mid) return find(v*2,l,mid,x);
    else return find(v*2+1,mid+1,r,x);
}
int main()
{
    freopen("lca.in","r",stdin);
    freopen("lca.out","w",stdout);
    int n,m;
    scanf("%d %d",&n,&m);
    fo(i,1,n) scanf("%d",&w[i]);
    fo(i,2,n)
    {
        int x,y;
        scanf("%d %d",&x,&y);
        link(x,y);
        link(y,x);
    }
    find(1);
    bool tf=false;
    while(m--)
    {
        char s[10];
        int u;
        scanf("%s %d",s,&u);
        if(s[0]=='M')
        {
            tf=true;
            if(vis[u]) continue;
            vis[u]=true;
            change(1,1,n,fr[u],en[u],w[u]);
            int last=u;
            for(int x=fa[u];x;x=fa[x])
            {
                change(1,1,n,dfn[x],dfn[x],w[x]);
                rep(i,x)
                {
                    int v=to[i];
                    if(v!=fa[x] && v!=last) change(1,1,n,fr[v],en[v],w[x]);
                }
                last=x;
                if(bz[x]) break;
                bz[x]=true; 
            }
        }
        else printf("%d\n",(tf?find(1,1,n,dfn[u]):-1));
    }
}
posted @ 2016-12-16 20:06  sadstone  阅读(29)  评论(0编辑  收藏  举报