最近公共祖先LCA

倍增法

预处理O(nlogn),判断O(logn)

步骤:1.预处理:用bfs预处理每个点向上走2的k次方步对应的结点是谁(用fa[a][k]=fa[fa[a][k-1]][k-1]来求),同时预处理每个节点对应的深度depth。在bfs初始化时使用哨兵depth[0]=0,表示一                   个点跳的距离如果超过范围,那就把这个深度定为0。(0表示跳超了)

   2.将较深的点跳到较浅的点的那一层。(用二进制从大到小枚举法)

   3.将两点都跳到lca的下一层。

   4.返回fa[a][0](即lca)

 

#include <bits/stdc++.h>
using namespace std;
const int N=4e4+100,M=N*2;
int n,m;
int idx,h[N];
struct node
{
    int ne,to;
}e[M];

void add(int x,int y)         
{
    idx++;
    e[idx].ne=h[x];e[idx].to=y;h[x]=idx;
}

int fa[N][16],depth[N];
void bfs(int root)        //预处理每个点向上走2的k次方步能走到的结点 
{
    memset(depth,0x3f,sizeof depth);
    depth[0]=0,depth[root]=1;    //0为哨兵 ,根节点的深度为1 
    queue<int>q;
    q.push(root);
    
    while(q.size())
    {
        int x=q.front();q.pop();
        
        for(int i=h[x];i;i=e[i].ne)        //建树 
        {
            int y=e[i].to;
            if(depth[y]>depth[x]+1)
            {
                q.push(y);
                depth[y]=depth[x]+1;
                fa[y][0]=x;        //y的父节点为x
                 
                for(int k=1;k<=15;k++)        //预处理y点向上走2的k次方步到达的结点 
                    fa[y][k]=fa[fa[y][k-1]][k-1];
            }
        }
        
    }
}

int lca(int a,int b)        //倍增法 
{
    if(depth[a]<depth[b])swap(a,b);        //将a定义为较深的结点 
    
    for(int k=15;k>=0;k--)    //从大到小枚举(将两个结点的深度相同) 
    {
        //a对应2的k次方对应的结点的深度仍大于等于b的深度
        //那就将a跳到2的k次方对应的结点 
        if(depth[fa[a][k]]>=depth[b])    
            a=fa[a][k];
    }
    
    if(a==b)return a;        //如果a跳完之后的结点是b,那就返回(说明b是a、b的lca) 
    for(int k=15;k>=0;k--)    //将两点跳到lca的下一层 
    { 
        if(fa[a][k]!=fa[b][k])    //哨兵优化,如果跳的超过已有范围了,就是0 
        {
            a=fa[a][k];
            b=fa[b][k];
        }
    }
    
    return fa[a][0];  //返回lca(因为a目前在lca的下一层,所以fa[a][0]就是lca) 
    
}


int main()
{
    scanf("%d",&n);
    
    int root,a,b;
    for(int i=1;i<=n;i++)
    {
        scanf("%d%d",&a,&b);
        if(b==-1)root=a;
        else add(a,b),add(b,a);
    }
    
    bfs(root);        //预处理 
    scanf("%d",&m);
    while(m--)        //查询 
    {
        int x,y;
        scanf("%d%d",&x,&y);
        int p=lca(x,y);
        if(x==p)puts("1");
        else if(y==p)puts("2");
        else puts("0");
    }
    
    
    return 0;
}
View Code

 

  

posted @ 2022-06-02 10:40  wellerency  阅读(21)  评论(0编辑  收藏  举报