HDU6031 Innumerable Ancestors 倍增 - 题意详细概括 - 算法详解

去博客园看该题解

题目

查看原题 - HDU6031 Innumerable Ancestors

题目描述

  有一棵有n个节点的有根树,根节点为1,其深度为1,现在有m个询问,每次询问给出两个集合A和B,问LCA(x,y)(x∈A,y∈B)的深度最大为多少。

输入描述

  有多组数据(数据组数<=5)

  对于每一组数据,首先2个数n,m,表示有根树的节点个数和询问个数。然后n-1行,每行2个数a,b表示节点a和节点b之间存在直接的连边;接下去2m行,每两行,分别描述当前询问的集合A和集合B;对于一个集合,用一行来描述,该行第一个数K表示集合元素的个数,后面K个数表示集合中的元素。

输出描述

  一个整数,表示LCA(x,y)(x∈A,y∈B)的最大深度。

数据范围

  n,m<=100000, 1<=a,b<=n, ΣK<=100000, 1<=集合中的元素<=n

 

题解

  问最大深度,那么我们思考是否可以二分答案。

  当然可以,本题的条件满足二分答案的前提,LCA基本的性质还是比较明显的。(假设a和b深度一样)设anst[x][y]为节点x往上走y步到达的祖先,对于一个k,如果anst[a][k]==anst[b][k],那么对于k'(k'>k),一定有anst[a][k']==anst[b][k'];对于一个k,如果anst[a][k]!=anst[b][k],那么对于k'(k'<k),一定有anst[a][k']!=anst[b][k'],而且LCA(a,b)=LCA(anst[a][k],anst[b][k])。

  二分答案深度d完成之后,那么就剩下了编一个子程序判定的事情了。

  那么如果判定呢?

  已知祖先深度,那么就知道了每一个点所对应的祖先了是吧?那么,判断是否有公共祖先,其实就是判断A集合所对应的祖先集合与B集合所对应的祖先集合是否有交集——因为ΣK<=100000, 所以对于每一个集合元素找出它的某一深度的祖先这个复杂度貌似还是不够,ΣK*n应该会超(如果有人用ΣK*n的判定复杂度过了本题, 跪求留代码) 。那么我们要更快的找到这个祖先,那么是什么?倍增啊!

  fa[i][j]表示与节点i深度差为2^j的i的祖先,那么不难写出转移方程:

  fa[i][0]=father[i],fa[i][j]=fa[fa[i][j-1]][j-1] (father[i]表示节点i的父亲节点)

  So,求某一深度的祖先就是和倍增求LCA的前一半类似的了。

  至于两个集合判断交集,就是排个序,然后两个指针扫过去就可以了。

  注意: 在求祖先时,要首先把那些不合法的祖先过滤掉; 在判断交集的时候,要注意边界情况!

 

代码

 

#include <cstring>
#include <cstdlib>
#include <algorithm>
#include <cstdio>
#include <cmath>
using namespace std;
const int N=100005,M=N*2,rt=1;
struct Edge{
    int cnt,y[M],nxt[M],fst[N];
    void set(){
        cnt=0;
        memset(y,0,sizeof y);
        memset(nxt,0,sizeof nxt);
        memset(fst,0,sizeof fst);
    }
    void add(int a,int b){
        y[++cnt]=b,nxt[cnt]=fst[a],fst[a]=cnt;
    }
}e;
int n,m,depth[N],fa[N][20],ta,a[N],tb,b[N],ansta[N],anstb[N];
void build(int prev,int rt){
    fa[rt][0]=prev,depth[rt]=depth[prev]+1;
    for (int i=1;(1<<i)<=depth[rt];i++)
        fa[rt][i]=fa[fa[rt][i-1]][i-1];
    for (int i=e.fst[rt];i;i=e.nxt[i])
        if (e.y[i]!=prev)
            build(rt,e.y[i]);
}
int get_kth_anst(int p,int k){
    for (int i=k,j=0;i>0;i>>=1,j++)
        if (i&1)
            p=fa[p][j];
    return p;
}
bool check(int d){
    int at=0,bt=0;
    for (int i=1;i<=ta;i++)
        if (depth[a[i]]>=d)
            ansta[++at]=get_kth_anst(a[i],depth[a[i]]-d);
    for (int i=1;i<=tb;i++)
        if (depth[b[i]]>=d)
            anstb[++bt]=get_kth_anst(b[i],depth[b[i]]-d);
    if (at==0||bt==0)
        return 0;
    int pa=1,pb=1;
    sort(ansta+1,ansta+at+1);
    sort(anstb+1,anstb+bt+1);
    if (ansta[1]==anstb[1])
        return 1;
    while (pa<=at&&pb<=bt){
        while (pa<=at&&ansta[pa]<anstb[pb])
            pa++;
        if (pa>at)
            break;
        if (ansta[pa]==anstb[pb])
            return 1;
        while (pb<=bt&&ansta[pa]>anstb[pb])
            pb++;
        if (pb>bt)
            break;
        if (ansta[pa]==anstb[pb])
            return 1;
    }
    return 0;
}
int main(){
    while (~scanf("%d%d",&n,&m)){
        e.set();
        for (int i=1,a,b;i<n;i++)
            scanf("%d%d",&a,&b),e.add(a,b),e.add(b,a);
        depth[0]=-1;
        build(0,rt);
        while (m--){
            scanf("%d",&ta);
            for (int i=1;i<=ta;i++)
                scanf("%d",&a[i]);
            scanf("%d",&tb);
            for (int i=1;i<=tb;i++)
                scanf("%d",&b[i]);
            int le=0,ri=n-1,mid,ans=0;
            while (le<=ri){
                mid=(le+ri)>>1;
                if (check(mid))
                    le=mid+1,ans=mid;
                else
                    ri=mid-1;
            }
            printf("%d\n",ans+1);
        }
    }
    return 0;
}
  

 

为了方便大家找茬,特地附上一份造数据的PASCAL代码,用于对拍。

var
    t, i: longint;
function min(a, b: longint): longint;
    begin
        if (a > b) then
            exit(b);
        exit(a);
    end;
procedure make_list(n ,m: longint);
    var
        i, j: longint;
    begin
        write(m, ' ');
        j := 0;
        for i := 1 to m do 
        begin
            j := j + random(n - j - m + i) + 1;
            write(j, ' ');
        end;
        writeln;
    end;
procedure mkdata;
    const 
        maxn = 150;
        maxm = 150;
        add = 40;
    var
        n, m, i, j, x, y, a, b: longint;
    begin
        n := random(maxn) + 1;
        m := random(maxm) + 1;
        writeln(n, ' ', m);
        for i := 2 to n do 
        begin
            x := i;
            y := random(i - 1) + 1;
            if (random(2) = 1) then
                writeln(x, ' ', y)
            else
                writeln(y, ' ', x);
        end;
        writeln;
        for i := 1 to m do 
        begin
            a := min(random(maxn div m + add)+2, n);
            b := min(random(maxn div m + add)+2, n);
            make_list(n, a);
            make_list(n, b);
        end;
        writeln;
    end;
begin
    assign(output, 'anst.in');
    rewrite(output);
    randomize;
    t := random(2) + 1;
    for i := 1 to t do
        mkdata;
    close(output);
end.

 

posted @ 2017-07-30 20:33  zzd233  阅读(450)  评论(0编辑  收藏  举报