LCA离线Tarjan,树上倍增入门题

  离线Tarjian,来个JVxie大佬博客最近公共祖先LCA(Tarjan算法)的思考和算法实现还有zhouzhendong大佬的LCA算法解析-Tarjan&倍增&RMQ(其实你们百度lca前两个博客就是。。。)

  LCA是最近公共祖先的意思,在上图的话像4和5的最近公共祖先就是2,而4和7的最近公共祖先是1,从某种意义上讲如果不怕超时的话,每次直接暴力搜索是可以找到每两个节点的最近公共祖先的,不过红红的TLE不好看,要想生活过得去,还是得看点AC的绿。

  而Tarjan求lca是离线算法,也就是对于所有的询问需要全输入完处理之后再输出答案,而不是输入一组询问输出相应的答案,JVxie大佬的讲解中有很详细的模拟过程,我就不重复了,就简单讲讲自己的理解。

  首先,我们需要保存所有询问,像建边一样我们记录下每个节点与其有询问关系的点,像上图的询问4和5的最近公共祖先的话,我们可以在4和5间类似建边一样加上一个关系,用我自己的代码来说就是(我是习惯链式前向星建边)

struct Side{
  int v,c,ne;
}S[2*N],Q[2*M];

建边:

void adds(int u,int v,int c)
{
  S[sn].v=v;
  S[sn].c=c;
  S[sn].ne=heads[u];
  heads[u]=sn++;
}

建询问关系:

void addq(int u,int v,int id)
{
  Q[qn].v=v;
  Q[qn].c=id;
  Q[qn].ne=headq[u];
  headq[u]=qn++;
}

边里面的c就代表边的权值,而询问关系里的c就代表它是第几组询问

  然后就是dfs了,我们先对每个点深搜也就是对它的子树遍历一遍,并且用并查集把它的子节点归到它这里,然后当这个点的子树都遍历完后,我们就去遍历和它有询问关系的点,如果和它有询问关系的某个点已经被搜索过了的话,那么某个点当前并查集所归属到的那个点就是他们的最近公共祖先。

  因为是按深搜的顺序遍历树,一个节点一开始是归属到它父节点的,如果它的兄弟节点或者它的父节点和它有询问关系的话,那么最近公共节点就是它的父节点,而如果是其他不跟它再同一个分支的节点跟它有询问关系的话,当它的父节点搜索完,它也会跟着它的父节点归属到它的爷爷节点,一直往上归属到两个节点的公共节点。光说太空洞,做题理解,引用大佬推荐的两道题

CODEVS 2370 小机房的树 传送门

  中文题,大意就是树上两个节点的最低路径,也就是他们到他们的最近公共祖先的距离,直接上代码了

 1 #include<cstdio>
 2 const int N=52013,M=75118;
 3 struct Side{
 4     int v,c,ne;
 5 }S[2*N],Q[2*M];
 6 int sn,qn,heads[N],headq[N],fa[N],vis[N],dis[N],ans[M];
 7 void init(int n) 
 8 {
 9     sn=qn=0;
10     for(int i=0;i<=n;i++)
11     {
12         fa[i]=i;//并查集,每个节点一开始自己是自己父亲 
13         dis[i]=0;
14         vis[i]=0;
15         heads[i]=headq[i]=-1;
16     }
17 }
18 void adds(int u,int v,int c)
19 {
20     S[sn].v=v;
21     S[sn].c=c;
22     S[sn].ne=heads[u];
23     heads[u]=sn++;
24 } 
25 void addq(int u,int v,int id)
26 {
27     Q[qn].v=v;
28     Q[qn].c=id;//记录下它的第几个询问 
29     Q[qn].ne=headq[u];
30     headq[u]=qn++;
31 }
32 int gui(int x){
33     return fa[x]==x ? x : fa[x]=gui(fa[x]);
34 }
35 void bing(int x,int y)
36 {
37     int gx=gui(x),gy=gui(y);
38     if(gx!=gy)
39         fa[gy]=gx;
40 }//并查集,我们老师说最简单的算法,嘤嘤嘤 
41 void dfs(int u,int f,int len)
42 {
43     dis[u]=len;
44     for(int i=heads[u];i!=-1;i=S[i].ne)
45     {
46         int v=S[i].v;
47         if(v!=f)
48         {
49             dfs(v,u,len+S[i].c);
50             bing(u,v);//把子节点归属到当前节点 
51             vis[v]=1;//标记当前节点以及访问过了 
52         }
53     }
54     for(int i=headq[u];i!=-1;i=Q[i].ne)
55     {//处理每个和它有关的询问 
56         int v=Q[i].v,id=Q[i].c;
57         if(vis[v])//当有关系的节点已经被访问过,处理该条询问 
58             ans[id]=dis[u]+dis[v]-2*dis[gui(v)];
59     }
60 }
61 int main()
62 {
63     int n,m,u,v,c;
64     scanf("%d",&n);
65     init(n);
66     for(int i=1;i<n;i++)
67     {
68         scanf("%d%d%d",&u,&v,&c);
69         adds(u,v,c);
70         adds(v,u,c);
71     }
72     scanf("%d",&m);
73     for(int i=0;i<m;i++)
74     {
75         scanf("%d%d",&u,&v);
76         addq(u,v,i);
77         addq(v,u,i);
78         //询问也要建双向的,因为不确定谁先访问到 
79     }
80     dfs(0,-1,0);
81     for(int i=0;i<m;i++)
82         printf("%d\n",ans[i]);
83     return 0;
84 }
搞基的虫子~~

  为什么前面说最后对树形dp有点理解呢,因为一般不会问最近公共祖先,而是问路径,所以我们需要记录下每个节点到根节点的距离,然后两个节点的距离就是它们到根节点的距离再减去两倍的它们最近公共祖先到根节点的距离(因为这段距离在两个节点到根节点的距离中重复了),比如上图,求4到5的距离的话,我们记录的是根节点,也就是1到4和1到5的距离,然后很明显,在1到4和1到5中都包含了1到2(4和5最近公共祖先)的距离,而我们要求4到5,是不需要1到2这段的距离的。

ZOJ 3195 Design the city 传送门

  题目大意也是求树上的最短路径,不过是求三个节点的,也是先上代码

 1 #include<cstdio>
 2 const int N=50018,M=70018;
 3 struct Side{
 4     int v,c,ne;
 5 }S[2*N],Q[6*M];
 6 int sn,qn,heads[N],headq[N],fa[N],vis[N],dis[N],ans[M];
 7 void init(int n)
 8 {
 9     sn=qn=0;
10     for(int i=0;i<=n;i++)
11     {
12         fa[i]=i;
13         dis[i]=0;
14         vis[i]=0;
15         heads[i]=headq[i]=-1;
16     }
17 }
18 void adds(int u,int v,int c)
19 {
20     S[sn].v=v;
21     S[sn].c=c;
22     S[sn].ne=heads[u];
23     heads[u]=sn++;
24 }
25 void addq(int u,int v,int id)
26 {
27     Q[qn].v=v;
28     Q[qn].c=id;
29     Q[qn].ne=headq[u];
30     headq[u]=qn++;
31 }
32 int gui(int x){
33     return fa[x]==x ? x : fa[x]=gui(fa[x]);
34 }
35 void bing(int x,int y)
36 {
37     int gx=gui(x),gy=gui(y);
38     if(gx!=gy)
39         fa[gy]=gx;
40 }
41 void dfs(int u,int f)
42 {
43     for(int i=heads[u];i!=-1;i=S[i].ne)
44     {
45         int v=S[i].v;
46         if(v!=f)
47         {
48             dis[v]=dis[u]+S[i].c;
49             dfs(v,u);
50             bing(u,v);
51             vis[v]=1;
52         }
53     }
54     for(int i=headq[u];i!=-1;i=Q[i].ne)
55     {
56         int v=Q[i].v,id=Q[i].c;
57         if(vis[v])
58             ans[id]+=dis[u]+dis[v]-2*dis[gui(v)];
59     }
60 }
61 int main()
62 {
63     int n,m,u,v,c,is=0;
64     while(~scanf("%d",&n))
65     {
66         if(is) 
67             printf("\n");
68         is=1;
69         init(n);
70         for(int i=1;i<n;i++)
71         {
72             scanf("%d%d%d",&u,&v,&c);
73             adds(u,v,c);
74             adds(v,u,c);
75         }
76         scanf("%d",&m);
77         int x,y,z;
78         for(int i=0;i<m;i++)
79         {
80             ans[i]=0;
81             scanf("%d%d%d",&x,&y,&z);
82             addq(x,y,i);//每两个点间都要建个询问关系 
83             addq(x,z,i);
84             addq(y,x,i);
85             addq(y,z,i);
86             addq(z,x,i);
87             addq(z,y,i);
88         }
89         dfs(0,-1);
90         for(int i=0;i<m;i++)
91             printf("%d\n",ans[i]/2);
92     }
93     return 0;
94 }
3个点就了不起啊

  其实和上题基本都一样,就是在求三个节点时,我们要分别求出两两节点间的最短距离,然后把三个结果相加起来除2就是答案。因为像a,b,c三个节点,我们求出a到b,和b到c,以及a到c的距离,所需要求的距离就会重复了一遍,比如上图求3,4,5的距离,也就是要求3到1,1到2,2到4,2到5这些边,而我们求出3到4是,3到1,1到2,2到4,求出4到5是,4到2,2到5,求出3到5就是,3到1,1到2,2到5,可以看到我们需要求的边都刚好多走了一遍。所以只要用lcd求出两两的最短路,3个加起来除2就是3个节点间的最短路。

  夜深了,树上倍增明天再更、树上倍增,我个人感觉它和树链剖分的关系就像树状数组和线段树,有些操作树上倍增做不了,但相同操作,比如求lca,树上倍增的效率更高,而且实现更简单。先放个大佬博客

  树上倍增的写法和应用(详细讲解,新手秒懂)Saramanda大佬的,简短精辟。

  首先,树上倍增是基于二进制的思想,我们把一个数,比如10我们可以写成10=23+21,也就是1010,那么如果我们保存一个fa[i][j]表示第i个节点的第2j个祖先的话,那么我们求一个节点A的第10个祖先的话,我们就可以先找到它第2个祖先B,然后B再找到它的第8个祖先C,C也就是A的第十个祖先。因为一个节点和它的父亲肯定是在同一条链上,那么它的祖先关系是累加的,就像它的第十个组先就是它第二个祖先的第八个主线,我们把要找的第K个转换为二进制的话,我们就可以根据对位i上有没有1来直接上升到第(1<<i)个祖先,整个算法的复杂度就变成了log级的了,简单的描述就是

找节点u的第K个祖先:

for(int i=0;i<=fn;i++)//fn为最多可能有2fn个祖先,一般不超过20多

{//从fn到0也是可以的

  if((1<<i)&k)

    u=fa[u][i];

}

比如K是10的话,1001,其实就是在每个1的位置上升也就是呈2(1<<i)倍数增加

  既然可以求第k个祖先了,那我们怎么用树上倍增来求lca呢,不同于前面的tarjan算法,树上倍增求lca是在线算法,也就是输入一组询问就可以求相应的结果,所以我们要在dfs时先把fa这个记录祖先的数组还有deep记录深度的数组处理出来。具体操作就是,在每次遍历每个节点u的子树之前,我们先从它第1个祖先开始,依次判断它的第1,2,4,8,2fn-1是否已经存在,已经存在的话,那么fa[u][i]=fa[fa[u][i-1]][i-1],因为2i=2i-1+2i-1嘛,如果第2i-1个祖先都还没存在,那么就不用再向下找第2i,2i+1个等等,代码实现就是

void dfs(int u)
{
  for(int i=1;i<=fn;i++)
  {
   int f=fa[u][i-1];
   if(f==-1)//如果第2i-1个祖先不存在,直接可以结束了
    break;
   fa[u][i]=fa[f][i-1];//否则就是倍增,2i=2i-1+2i-1
  }
  for(int i=head[u];i!=-1;i=S[i].ne)//然后遍历和它相连的点
  {
    int v=S[i].v;
    if(v!=fa[u][0])//fu[u][0]就是u的父节点
    {
      fa[v][0]=u;//v的父节点就是u
      deep[v]=deep[u]+1;//记录深度
      dis[v]=dis[u]+S[i].w;//到根节点的距离
      dfs(v);
    }
  }
}

  处理出deep和fa数组后,然后怎么求lca呢,再次用到这个图,我们如果求7和8的lca的话,deep记录的深度,我们从0开始的话,deep也相当于它有多少个祖先,像deep[8]=3,deep[7]=2,我们可以发现7和8不在同一层次,那么我们就要把它们统一层次,也就把8提上来。deep[8]和deep[7]相差了1,我们就找到8的第1个祖先5,这时5和7就同一层次了。同一层次的话,那么它们的最近公共祖先肯定就都是两人第一个相同的第2i个祖先,这个统一层次的过程就是

if(deep[u]<deep[v])//让u是深度大的节点
{
  int t=u;
  u=v;v=t;
}
int disd=deep[u]-deep[v];//深度差
for(int i=0;i<=fn;i++)//把u提到第disd个祖先
  if((1<<i)&disd)
    u=fa[u][i];
if(u==v)//同一层次时u和v是同一节点直接可以返回了
  return u;

  当统一层次而u!=v时,我们继续从fn到0开始往下找,当找到某个不相等的节点时,我们就移到双双移到那个位置。就像i从fn到2时,fa[5][i],fa[7][i]都一直相等,都是不存在的节点(可以用-1代表),而fa[5][1]=fa[7][1]为1这些都不处理,当i=0时,fa[5][0]=2,fa[7][0]=3,两者不相等,所以5移到2,7移到5,最后fa[2][0]或者fa[3][0]也就是1,是5和7的最近公共祖先。

  为什么呢?就假如我们找同一层次的a和b的公共祖先,因为i是从fn到0,从2fn到20个祖先的找,当a和b第2i个祖先c和d不相等时,我们就把a移到c,b移到d,因为i是逐渐减小的,那么最后c和d就是a和b最近公共祖先下一层的那两个祖先。

  那从0到fn的遍历找第一个相同的祖先可不可以呢?答案是不可以的,因为我们是以20,21,22,23,这样2i级遍历的,那么如果要找的最近公共祖先是a和b的第6个祖先,那么fa[a][2],fa[b][2]不相同,接着就是到fa[a][3]和fa[b][3]了,直接把6跳过了。而从fn到0的话,c=fa[a][2]和d=fa[b][2]不相同,a移到c,b移到d,然后fa[c][1]和fa[d][1]相同跳过,e=fa[c][0]和f=fa[d][0]不同,c再转到e,d转到f,最终a,b最近公共祖先就是fa[e][0]或者fa[f][0],也就是22+20+20

  之前的两道题用树上倍增来实现的话就是

 1 #include<cstdio>
 2 const int N=52118,F=32;
 3 struct Side{
 4     int v,w,ne;
 5 }S[2*N];
 6 int sn,fn,head[N],dis[N],deep[N],fa[N][F];
 7 void init(int n)
 8 {
 9     sn=0;
10     fn=20;
11     for(int i=0;i<=n;i++)
12     {
13         head[i]=-1;
14         dis[i]=0;
15         deep[i]=0;
16         for(int j=0;j<=fn;j++)
17             fa[i][j]=-1;
18     }
19 }
20 void add(int u,int v,int w)
21 {
22     S[sn].v=v;
23     S[sn].w=w;
24     S[sn].ne=head[u];
25     head[u]=sn++;
26 }
27 void dfs(int u)
28 {
29     for(int i=1;i<=fn;i++)
30     {
31         int f=fa[u][i-1];
32         if(f==-1)
33             break;
34         fa[u][i]=fa[f][i-1];
35     }
36     for(int i=head[u];i!=-1;i=S[i].ne)
37     {
38         int v=S[i].v;
39         if(v!=fa[u][0])
40         {
41             fa[v][0]=u;
42             deep[v]=deep[u]+1;
43             dis[v]=dis[u]+S[i].w;
44             dfs(v);
45         }
46     }
47 }
48 int lca(int u,int v)
49 {
50     if(deep[u]<deep[v])
51     {
52         int t=u;
53         u=v;v=t;
54     }//让u为深度较深的那个节点 
55     int disd=deep[u]-deep[v];
56     for(int i=0;i<=fn;i++)
57         if((1<<i)&disd)
58             u=fa[u][i];
59     //将u移到和v相同深度的那个祖先 
60     if(u==v)//相同层次时是同一节点 
61         return u;
62     for(int i=fn;i>=0;i--) 
63         if(fa[u][i]!=fa[v][i])
64         {
65             u=fa[u][i];
66             v=fa[v][i];
67         }
68     //找最近公共祖先下一层的u,v的祖先 
69     return fa[u][0];
70 }
71 int main()
72 {
73     int n,m,u,v,w;
74     while(~scanf("%d",&n))
75     {
76         init(n);
77         for(int i=1;i<n;i++)
78         {
79             scanf("%d%d%d",&u,&v,&w);
80             add(u,v,w);
81             add(v,u,w);
82         }
83         dfs(0);
84         scanf("%d",&m);
85         while(m--)
86         {
87             scanf("%d%d",&u,&v);
88             int f=lca(u,v);
89             printf("%d\n",dis[u]+dis[v]-2*dis[f]);
90         }
91     }
92     return 0;
93 }
精力旺盛的虫子
#include<cstdio>
#include<cmath>
const int N=52118,F=32; 
struct Side{
    int v,w,ne;
}S[2*N];
int sn,fn,head[N],dis[N],deep[N],fa[N][F];
void init(int n)
{
    sn=0;
    fn=(int)log(n);
    for(int i=0;i<=n;i++)
    {
        head[i]=-1;
        dis[i]=0;
        deep[i]=0;
        for(int j=0;j<=fn;j++)
            fa[i][j]=-1;
    }
}
void add(int u,int v,int w)
{
    S[sn].v=v;
    S[sn].w=w;
    S[sn].ne=head[u];
    head[u]=sn++;
}
void dfs(int u)
{
    for(int i=1;i<=fn;i++)
    {
        int f=fa[u][i-1];
        if(f==-1)
            break;
        fa[u][i]=fa[f][i-1];
    }
    for(int i=head[u];i!=-1;i=S[i].ne)
    {
        int v=S[i].v;
        if(v!=fa[u][0])
        {
            fa[v][0]=u;
            deep[v]=deep[u]+1;
            dis[v]=dis[u]+S[i].w;
            dfs(v);
        }
    }
}
int lca(int u,int v)
{
    if(deep[u]<deep[v])
    {
        int t=u;
        u=v;v=t;
    }
    int disd=deep[u]-deep[v];
    for(int i=0;i<=fn;i++)
        if((1<<i)&disd)
            u=fa[u][i];
    if(u==v)
        return u;
    for(int i=fn;i>=0;i--)
        if(fa[u][i]!=fa[v][i])
        {
            u=fa[u][i];
            v=fa[v][i];
        }
    return fa[u][0];
}
int main()
{
    int n,m,u,v,w,x,is=0;
    while(~scanf("%d",&n))
    {
        if(is)
            printf("\n");
        is=1;
        init(n);
        for(int i=1;i<n;i++)
        {
            scanf("%d%d%d",&u,&v,&w);
            add(u,v,w);
            add(v,u,w);
        }
        dfs(0);
        scanf("%d",&m);
        while(m--)
        {
            scanf("%d%d%d",&u,&v,&x);
            int f,ans=0;
            //计算任意两点间的距离 
            f=lca(u,v); 
            ans+=dis[u]+dis[v]-2*dis[f];
            f=lca(u,x);
            ans+=dis[u]+dis[x]-2*dis[f];
            f=lca(v,x);
            ans+=dis[v]+dis[x]-2*dis[f];
            printf("%d\n",ans/2);
        }
    }
    return 0;
}
3个点了不起

  这里推荐个树上倍增的模板题

Query on a tree IISPOJ - QTREE2 

  题目大意是,一棵树有两种询问,一直就是问a到b的最短路径,另一种就是问a到b路径上第k个节点是哪个?

  前一种操作很简单就是dis[a]+dis[b]-2*dis[lca(a,b)],而第二种的话,我们求出lca(a,b),那么a跟lca(a,b)的深度差就是这条路径上有多少个祖先,所以k-1<=这个深度差的话,那么就是找a的第k-1个祖先,否则就是找b的b跟lca(a,b)的深度差+1-(k-a跟lca(a,b)的深度差)个祖先,

  1 #include<cstdio>
  2 #include<algorithm>
  3 using namespace std;
  4 const int N=10118,F=32;
  5 struct Side{
  6     int v,w,ne;
  7 }S[2*N];
  8 char op[8];
  9 int sn,fn,head[N],dis[N],deep[N],fa[N][F];
 10 void init(int n)
 11 {
 12     sn=0;fn=21;
 13     for(int i=0;i<=n;i++)
 14     {
 15         dis[i]=0;
 16         head[i]=-1;
 17         deep[i]=0;
 18         for(int j=0;j<=fn;j++)
 19             fa[i][j]=-1;
 20     }
 21 }
 22 void add(int u,int v,int w)
 23 {
 24     S[sn].w=w;
 25     S[sn].v=v;
 26     S[sn].ne=head[u];
 27     head[u]=sn++;
 28 }
 29 void dfs(int u)
 30 {
 31     for(int i=1;i<=fn;i++)
 32     {
 33         int f=fa[u][i-1];
 34         if(f==-1)
 35             break;
 36         fa[u][i]=fa[f][i-1];
 37     }
 38     for(int i=head[u];i!=-1;i=S[i].ne)
 39     {
 40         int v=S[i].v;
 41         if(v!=fa[u][0])
 42         {
 43             fa[v][0]=u;
 44             dis[v]=dis[u]+S[i].w;
 45             deep[v]=deep[u]+1;
 46             dfs(v);
 47         }
 48     }
 49 }
 50 int kth(int u,int k)//找u的第k个祖先 
 51 {
 52     for(int i=0;i<=fn;i++)
 53         if((1<<i)&k)
 54             u=fa[u][i];
 55     return u;
 56 }
 57 int lca(int u,int v)//找u,v的最近公共祖先 
 58 {
 59     if(deep[u]<deep[v])
 60         swap(u,v);
 61     int disd=deep[u]-deep[v];
 62     u=kth(u,disd);
 63     if(u==v)
 64         return u;
 65     for(int i=fn;i>=0;i--)
 66         if(fa[u][i]!=fa[v][i])
 67         {
 68             u=fa[u][i];
 69             v=fa[v][i];
 70         }
 71     return fa[u][0];
 72 }
 73 int main()
 74 {
 75     int t,n,u,v,w;
 76     scanf("%d",&t);
 77     while(t--)
 78     {
 79         scanf("%d",&n);
 80         init(n);
 81         for(int i=1;i<n;i++)
 82         {
 83             scanf("%d%d%d",&u,&v,&w);
 84             add(u,v,w);
 85             add(v,u,w);
 86         }
 87         dfs(1);
 88         while(scanf("%s",op)&&op[1]!='O')
 89         {
 90             scanf("%d%d",&u,&v);
 91             int f=lca(u,v);
 92             if(op[0]=='K')
 93             {
 94                 scanf("%d",&w);
 95                 int du=deep[u]-deep[f],dv=deep[v]-deep[f];
 96                 //du:u和lca的深度差,dv:v和lca的深度差,
 97                 //之所以要-1,+1是因为u,v本身也是路径上的节点 
 98                 if(w-1<=du)
 99                     printf("%d\n",kth(u,w-1));
100                 else
101                     printf("%d\n",kth(v,dv+1-(w-du)));
102             }
103             else
104                 printf("%d\n",dis[u]+dis[v]-2*dis[f]);
105         }
106         printf("\n");
107     }
108     return 0;
109 }
倍增倍增倍增
posted @ 2019-03-27 01:19  新之守护者  阅读(351)  评论(0编辑  收藏  举报