poj 3450 Corporate Identity【后缀数组,求多个串的最长公共子串】

题意就是求多个串的最长公共子串,把多个串合并成一个串,中间用没有出现过的字符隔开,然后对这个串求后缀数组,然后二分枚举答案。

View Code
/* 后缀数组倍增算法
 * 并且计算了height[], height[i] = LCP(i-1, i),  LCP(i, j)=lcp(suffix(sa[i]), suffix(sa[j])) 
 * 时间复杂度:N*logN
 * */
//poj 3450.
#include <cstdio>
#include <cstring>
#include <algorithm>
using namespace std;

const int MAXM = 1000000 + 10;
const int LOG_MAXM = 20;

int buc[MAXM];
int X[MAXM], Y[MAXM];
int rank[MAXM], height[MAXM], sa[MAXM];

void cal_height(int *r, int n)
{
    int i, j, k = 0; // height[i] = LCP(i-1, i)
    //rank[sa[0]] = 0;
    for (i = 1; i <= n; ++i) rank[sa[i]] = i; // sa[0] 是添加的0字符,0字符排第0位,所以可以忽略
    for (i = 0; i < n; height[rank[i++]] = k) // h[i] = height[rank[i]], h[i] >= h[i-1] - 1 
        for (k?k--:0, j = sa[rank[i] - 1]; r[i+k] == r[j+k]; ++k) ;
}

bool cmp(int *r, int a, int b, int l) 
{
    return r[a] == r[b] && r[a+l] == r[b+l];
}

void suffix(int *r, int n, int m=128) //字符串r末尾是0,所以长度应该是n+1,即原来基础上多加0
{
    int i, l, p, *x = X, *y  = Y, *t;

    for (i = 0; i < m; ++i) buc[i] = 0;
    for (i = 0; i < n; ++i) buc[ x[i] = r[i] ]++;
    for (i = 1; i < m; ++i) buc[i] += buc[i - 1];
    for (i = n - 1; i >= 0; --i) sa[ --buc[x[i]] ] = i;

    for (l = 1, p = 1; p < n; m = p, l <<= 1) {
        for (p = 0, i = n - l; i < n; ++i) y[p++] = i; // 末尾l个子串没有l长的第二关键字
        for (i = 0; i < n; ++i) { // 根据第二关键字,存第一关键字的位置
            if (sa[i] >= l) // 保证有第一关键字
                y[p++] = sa[i] - l; // 记录第一关键字的位置
        }
        for (i = 0; i < m; ++i) buc[i] = 0; //根据第一关键字排序 
        for (i = 0; i < n; ++i) buc[ x[y[i]] ]++;
        for (i = 1; i < m; ++i) buc[i] += buc[i - 1];
        for (i = n - 1; i >= 0; --i) sa[ --buc[ x[y[i]] ] ] = y[i];
        for (t = x, x = y, y = t, x[ sa[0] ] = 0, i = 1, p = 1; i < n; ++i) {//为下次排序准备2*l长的子串的rank值
            x[ sa[i] ] = cmp(y, sa[i-1], sa[i], l) ? p-1 : p++; // 新的rank值
        }
    }
    cal_height(r, n - 1);
}

int best[MAXM][LOG_MAXM];
void init_rmq(int n)
{
    for (int i = 1; i <= n; ++i) best[i][0] = height[i];
    for (int l = 1; (1 << l) <= n; ++l) {
        int limit = n - (1 << l) + 1;
        for (int i = 1; i <= limit; ++i) {
            best[i][l] = min(best[i][l-1], best[i + (1 << (l-1))][l-1]);
        }
    }
}
int lcp(int a, int b) // 询问a, b后缀的最长前缀
{
    a = rank[a], b = rank[b];
    if (a > b) swap(a, b);
    ++a; // 因为height[i]记录的是LCP(i-1, i),所以要++a;
    int l = 0;
    for (; (1 << l) <= b - a + 1; ++l) ;
    --l;
    return min(best[a][l], best[b-(1<<l)+1][l]);
}

char str[MAXM];  //输入的串。
char s[MAXM];   //输出的串。
int num[MAXM];
int ans[MAXM];  //把每个串分成不同的区域,中间用一个没有出现的字符隔开。
bool vis[4001]; //这个地方最开始开了MAXM的数组,提交上去TLE了,因为数组太大,下面要不断的memset,所以超时。

bool check(int mid, int n, int nt)
{
    int tot = 0;
    memset(vis, 0, sizeof(vis));
    for(int i = 2; i <= n; i++) {
        //这个地方理解了很长时间,因为mid表示的是长度,所以height[i] < mid表示当前排名i的串和排名i-1的
        //串的公共前缀长度小于我们要枚举的长度,所以不成立,所以要对tot清零,并初始化标记数组。
        if(height[i] < mid) {
            tot = 0;
            memset(vis, 0, sizeof(vis));
            continue;
        }

        if( !vis[ans[sa[i-1]]] ) vis[ans[sa[i-1]]] = 1, tot++;
        if( !vis[ans[sa[i]]] ) vis[ans[sa[i]]] = 1, tot++;

//        当tot == nt时,说明找到了长度为mid的串,根据height数组的性质,如果有多个长度相等的子串,
//        第一次找到的为字典序最小的,所以符合题意。
        if(tot == nt) {
            for(int j = 0; j < mid; j++) {
                s[j] = num[sa[i] + j] + 'a' - 1;
            }
            s[mid] = '\0';
            return 1;
        }
    }
    return 0;
}

int main()
{
    // input
    int nt;
    while(scanf("%d", &nt), nt) {
        int n = 0;
        int tmp = 27; //小写字符从1到26,所以没出现过的字符从27开始标记。
        for(int i = 0; i < nt; i++) {
            scanf("%s", str);
            int len = strlen(str);
            for(int j = 0; j < len; j++) {
                num[n] = str[j] - 'a' + 1;  //不能从0-25,因为下面num[n] = 0;
                ans[n++] = i;
            }
            ans[n] = tmp;
            num[n++] = tmp++;
        }
        num[n] = 0;
        suffix(num, n + 1, tmp); //tmp是总共出现的字符的个数,并且+1; 
        init_rmq(n);

        int left, right;
        left = 0;
        right = strlen(str);

        //用二分的方法找最长公共子串,mid为长度。
        int flag = 0;
        while(left <= right) {
            int mid = (left+right)/2;
            if( check(mid, n, nt) ) {
                left = mid + 1;
                flag = mid;
            } else {
                right = mid - 1;
            }
        }

        if(flag) {
            printf("%s\n", s);
        } else {
            printf("IDENTITY LOST\n");
        }
    }
    return 0;
}
posted @ 2012-08-06 16:37  小猴子、  阅读(289)  评论(0编辑  收藏  举报