关于lca

倍增$ST$表

预处理复杂度 $O(n \log n)$

单次查询复杂度 $O(\log n)$

$RMQ$倍增的思想。

#include<iostream>
#include<cstdio>
#include<cstring>
#include<algorithm>
#include<cmath>
using namespace std;
inline int read()
{
    int f=1,ans=0;char c;
    while(c<'0'||c>'9'){if(c=='-')f=-1;c=getchar();}
    while(c>='0'&&c<='9'){ans=ans*10+c-'0';c=getchar();}
    return ans*f;
}
int q,a[500001],fa[500001][21];
struct node{
    int x;int y;
    int nex;
}ss[1000001];
int head[500001];
int cnt=1;
void add(int a,int b)
{
    ss[cnt].x=a,ss[cnt].y=b;
    ss[cnt].nex=head[a],head[a]=cnt++;
    return;
}
int deep[500001];
void dfs(int f,int fath)
{
    deep[f]=deep[fath]+1;
//    minn[f][0]=min(a[f],a[fath]);
    fa[f][0]=fath;
    for(int i=1;(1<<i)<=deep[f];i++) fa[f][i]=fa[fa[f][i-1]][i-1];
    for(int i=head[f];i!=-1;i=ss[i].nex)
        if(ss[i].y!=fath) dfs(ss[i].y,f);
}
int n,m,s;
int lca(int u,int v)
{
    if(deep[u]<deep[v]) swap(u,v);
    for(int i=20;i>=0;i--)
        if(deep[u]-(1<<i)>=deep[v]) u=fa[u][i];
    if(u==v) return u;
    for(int i=20;i>=0;i--)
    {
        if(fa[u][i]==fa[v][i]) continue;
        else u=fa[u][i],v=fa[v][i];
    }
    return fa[u][0];
}
int main()
{
    memset(head,-1,sizeof(head));
    n=read(),m=read(),s=read();
    for(int i=1;i<n;i++)
    {
        int u=read(),v=read();
        add(u,v),add(v,u);
    }
    dfs(s,0);
    for(int i=1;i<=m;i++)
    {
        int u=read(),v=read();
        printf("%d\n",lca(u,v));
    }
}/*
5 5 4
3 1
2 4
5 1
1 4
3 5
1 2 
4 5*/
View Code

 

$Tarjan$算法

 离线操作,总复杂度约$O(n+q)$

主要就是若要求两点之间$lca$,思想是$lca(u,v)$封锁了这两颗子树,记录一下当前节点是否回溯过,同时用并查集维护一下当前父亲。

#include<iostream>
#include<cstring>
#include<cstdio>
#include<algorithm>
using namespace std;
inline int read(){
    int f=1,ans=0;char c=getchar();
    while(c<'0'||c>'9'){if(c=='-')f=-1;c=getchar();}
    while(c>='0'&&c<='9'){ans=ans*10+c-'0';c=getchar();}
    return f*ans;
}
const int N=40001;
const int Q=201;
struct node{
    int u,v,w,nex;
}x[N<<1];
struct Node{
    int u,v,id,nex;
}query[Q<<1];
int head1[N],head[N],cnt1,cnt,T;
void add(int u,int v,int w){
    x[cnt].u=u,x[cnt].v=v,x[cnt].w=w,x[cnt].nex=head[u],head[u]=cnt++;
}
void Add(int u,int v,int id){
    query[cnt1].u=u,query[cnt1].v=v,query[cnt1].id=id,query[cnt1].nex=head1[u],head1[u]=cnt1++;
}
int calc[Q],vis[N],fa[N];
void init(){memset(head,-1,sizeof(head)),memset(head1,-1,sizeof(head1)),cnt=0,cnt1=0,memset(vis,0,sizeof(vis));}
int find(int x){
    if(fa[x]==x) return x;
    return fa[x]=find(fa[x]);
}
int dis[N];
void dfs(int f,int fath){
    for(int i=head[f];i!=-1;i=x[i].nex){
        if(x[i].v==fath) continue;
        dis[x[i].v]=dis[f]+x[i].w;
        dfs(x[i].v,f);
        fa[x[i].v]=f;
    }
    for(int i=head1[f];i!=-1;i=query[i].nex){
        int u=query[i].u,v=query[i].v,id=query[i].id;
        if(vis[v]){
            int lca=find(v);
            calc[id]=(dis[u]+dis[v])-2*dis[lca];
        }
    }vis[f]=1;
}
int main(){
    T=read();
    while(T--){
        init();
        int n=read(),q=read();
        for(int i=1;i<=n;i++) fa[i]=i;
        for(int i=1;i<n;i++){int u=read(),v=read(),w=read();add(u,v,w),add(v,u,w);}
        for(int i=1;i<=q;i++){int u=read(),v=read();Add(u,v,i),Add(v,u,i);}
        dfs(1,0);
        for(int i=1;i<=q;i++) printf("%d\n",calc[i]);
    }
}
View Code

 

欧拉序列

预处理复杂度 $O(n \log n)$

单次查询复杂度 $O(1)$

与普通欧拉序不同,当每进入一个节点中,我们就将其编号记下,并且$lca(u,v)$是从$u$到$v$简单路径下深度最浅的点,就可以用$RMQ$维护即可。

#include<iostream>
#include<cstring>
#include<cstdio>
#include<algorithm>
#include<cmath>
using namespace std;
inline int read(){
    int f=1,ans=0;char c=getchar();
    while(c<'0'||c>'9'){if(c=='-')f=-1;c=getchar();}
    while(c>='0'&&c<='9'){ans=ans*10+c-'0';c=getchar();}
    return f*ans;
}
const int N=1000001;
int deep[N],n,root,cnt,st[N][21],num,in[N],out[N],head[N],q;
struct node{
    int u,v,nex;
}x[N<<1];
void add(int u,int v){
    x[cnt].u=u,x[cnt].v=v,x[cnt].nex=head[u],head[u]=cnt++;
}
void dfs(int f,int fath){
    in[f]=++num,deep[f]=deep[fath]+1;
    st[num][0]=f;
    for(int i=head[f];i!=-1;i=x[i].nex){
        if(x[i].v==fath) continue;
        dfs(x[i].v,f);
        st[++num][0]=f;
    }
}
void init(){
    for(int j=1;j<=log2(num);j++)
        for(int i=1;i+(1<<j)<=num;i++){
            int s1=st[i][j-1],s2=st[i+(1<<(j-1))][j-1];
            if(deep[s1]<deep[s2]) st[i][j]=st[i][j-1];
            else st[i][j]=st[i+(1<<(j-1))][j-1];
        }
}
int query(int u,int v){
    int l=in[u],r=in[v];
    if(l>r) swap(l,r);
    int k=log2(r-l+1);
    int s1=st[l][k],s2=st[r-(1<<k)+1][k];
    if(deep[s1]<deep[s2]) return s1;
    return s2;
}
int main(){
//    freopen("3.in","r",stdin);
    memset(head,-1,sizeof(head));
    n=read(),q=read(),root=read();
    for(int i=1;i<n;i++){int u=read(),v=read();add(u,v),add(v,u);}
    dfs(root,0);
    init();
    while(q--){
        int u=read(),v=read();
        printf("%d\n",query(u,v));
    }
}
View Code

 树链剖分

预处理复杂度 $O(n \log n)$

单次查询复杂度 $O(\log n)$

处理好轻重边后两点往上跳,一直到两点在一条重链上,深度最短的即为$lca$

#include<iostream>
#include<cstring>
#include<cstdio>
#include<algorithm>
#include<climits>
using namespace std;
inline int read(){
    int f=1,ans=0;char c=getchar();
    while(c<'0'||c>'9'){if(c=='-')f=-1;c=getchar();}
    while(c>='0'&&c<='9'){ans=ans*10+c-'0';c=getchar();}
    return f*ans;
}
const int N=500001;
int n,q,rt,size[N],cnt,deep[N],top[N],son[N],fa[N],head[N];
struct node{
    int u,v,nex;
}x[N<<1];
void add(int u,int v){
    x[cnt].u=u,x[cnt].v=v,x[cnt].nex=head[u],head[u]=cnt++;
}
struct LCA{
    void dfs1(int f,int fath){
        fa[f]=fath;
        size[f]=1;deep[f]=deep[fath]+1;
        for(int i=head[f];i!=-1;i=x[i].nex){
            if(x[i].v==fath) continue;
            dfs1(x[i].v,f);
            size[f]+=size[x[i].v];
            if(size[son[f]]<size[x[i].v]) son[f]=x[i].v;
        }return;
    }
    void dfs2(int f,int fath){
        if(son[f]){
            top[son[f]]=top[f];
            dfs2(son[f],f);
        }
        for(int i=head[f];i!=-1;i=x[i].nex){
            if(x[i].v==fath||x[i].v==son[f]) continue;
            top[x[i].v]=x[i].v;
            dfs2(x[i].v,f);
        }
    }
    int lca(int x,int y){
        int fx=top[x],fy=top[y];
        while(fx!=fy){
            if(deep[fx]<deep[fy]) swap(x,y),swap(fx,fy);
            x=fa[fx],fx=top[x];
        }
        if(deep[x]>deep[y]) swap(x,y);
        return x;
    }
}Q;
int main(){
    memset(head,-1,sizeof(head));
    n=read(),q=read(),rt=read();
    for(int i=1;i<n;i++){
        int u=read(),v=read();
        add(u,v),add(v,u);
    }
    Q.dfs1(rt,0),top[rt]=rt,Q.dfs2(rt,0);
    for(int i=1;i<=q;i++){
        int u=read(),v=read();
        printf("%d\n",Q.lca(u,v));
    }
}
View Code

 

posted @ 2018-12-27 22:40  siruiyang_sry  阅读(189)  评论(0编辑  收藏  举报