poj3415(后缀数组)

poj3415

题意

给定两个字符串,给出长度 \(m\) ,问这两个字符串有多少对长度大于等于 \(m\) 且完全相同的子串。

分析

首先连接两个字符串 A B,中间用一个特殊符号分割开。
按照 \(sa\) 的顺序(即枚举 \(height\) 值),进行分组,那么有公共前缀长大于等于 \(m\) 的都分到了一组,对于某一组,后缀串可能来自于 A 也可能来自于 B,那么对于 A 找前面的 B 串,对于 B 找前面的 A 串,如果某两个后缀串的公共前缀长为 \(l(l \geqslant m)\),那么显然会有 \(l - m + 1\) 对子串。
注意到这个性质: 对于两个后缀串 j 和 k,设 \(rnk[j] < rnk[k]\) ,LCP长度为 \(height[rnk[j]+1], height[rnk[j]+2], ... , height[rnk[k]]\) 中的最小值。
维护一个单调递增的栈(保证栈顶最大)可以用一个二维数组表示(\(q[][2]\)),一个是栈,一个是某个数的个数。
举个例子,如果连续的 \(height\) 值为 \(2 \ 3 \ 4\)\(m = 2\),前三个为 A 串,那么 \(2 \ 3 \ 4\) 全部入栈,且计算对答案的贡献 \(sum\)(不是直接加到答案上),即 \((2-2+1) + (3-2+1) + (4-2+1)\) ,到 B 串时,答案就加上了这个值,但是如果后面还有一个 B 串且 \(height\)\(3\),那么就要弹栈,且减去 \(sum\) 值多的那部分(前面多算了),栈里 \(4\) 的数量为 \(1\),所以 \(sum = sum - (4 - 3) * 1\) ,且栈里 \(3\) 的数量变为了 \(2\)\(4\) 对应的 A 串对于后面串提供的贡献减小了(注意前面的性质),所以\(4\) 变为了 \(3\) ),答案加上 \(sum\)

code

#include<cstdio>
#include<cstring>
#include<algorithm>
using namespace std;
typedef long long ll;
const int MAXN = 2e5 + 10;
const int INF = 1e9;
char s[MAXN];
int sa[MAXN], t[MAXN], t2[MAXN], c[MAXN], n; // n 为 字符串长度 + 1,即最后一位为数字 0
int rnk[MAXN], height[MAXN];
// 构造字符串 s 的后缀数组。每个字符值必须为 0 ~ m-1
void build_sa(int m) {
    int i, *x = t, *y = t2;
    for(i = 0; i < m; i++) c[i] = 0;
    for(i = 0; i < n; i++) c[x[i] = s[i]]++;
    for(i = 1; i < m; i++) c[i] += c[i - 1];
    for(i = n - 1; i >= 0; i--) sa[--c[x[i]]] = i;
    for(int k = 1; k <= n; k <<= 1) {
        int p = 0;
        for(i = n - k; i < n; i++) y[p++] = i;
        for(i = 0; i < n; i++) if(sa[i] >= k) y[p++] = sa[i] - k;
        for(i = 0; i < m; i++) c[i] = 0;
        for(i = 0; i < n; i++) c[x[y[i]]]++;
        for(i = 0; i < m; i++) c[i] += c[i - 1];
        for(i = n - 1; i >= 0; i--) sa[--c[x[y[i]]]] = y[i];
        swap(x, y);
        p = 1;
        x[sa[0]] = 0;
        for(i = 1; i < n; i++)
            x[sa[i]] = y[sa[i - 1]] == y[sa[i]] && y[sa[i - 1] + k] == y[sa[i] + k] ? p - 1 : p++;
        if(p >= n) break;
        m = p;
    }
}
void getHeight() {
    int i, j, k = 0;
    for(i = 0; i < n; i++) rnk[sa[i]] = i;
    for(i = 0; i < n - 1; i++) {
        if(k) k--;
        j = sa[rnk[i] - 1];
        while(s[i + k] == s[j + k]) k++;
        height[rnk[i]] = k;
    }
}
char s2[MAXN];
int q[MAXN][2];
int main() {
    int m;
    while(~scanf("%d", &m) && m) {
        scanf("%s%s", s, s2); // A 、B串
        int l = strlen(s), l2 = strlen(s2);
        s[l++] = '#';
        for(int i = 0; i < l2; i++) s[i + l] = s2[i];
        s[l + l2] = 0;
        n = l + l2 + 1;
        build_sa(128);
        getHeight();
        ll ans = 0, sum = 0;
        int top = 0;
        // 在 B 串前找 A
        for(int i = 2; i < n; i++) {
            int cnt = 0;
            if(height[i] < m) {
                top = 0; sum = 0;
                continue;
            }
            if(sa[i - 1] < l) {
                cnt++;
                sum += height[i] - m + 1;
            }
            while(top && q[top - 1][0] >= height[i]) {
                top--;
                sum -= (q[top][0] - height[i]) * q[top][1];
                cnt += q[top][1];
            }
            q[top][0] = height[i]; q[top++][1] = cnt;
            if(sa[i] >= l) ans += sum;
        }
        // 在 A 串前找 B
        sum = 0; top = 0;
        for(int i = 2; i < n; i++) {
            int cnt = 0;
            if(height[i] < m) {
                top = 0; sum = 0;
                continue;
            }
            if(sa[i - 1] >= l) {
                cnt++;
                sum += height[i] - m + 1;
            }
            while(top && q[top - 1][0] >= height[i]) {
                top--;
                sum -= (q[top][0] - height[i]) * q[top][1];
                cnt += q[top][1];
            }
            q[top][0] = height[i]; q[top++][1] = cnt;
            if(sa[i] < l) ans += sum;
        }
        printf("%lld\n", ans);
    }
    return 0;
}
posted @ 2017-07-22 21:53  ftae  阅读(234)  评论(0编辑  收藏  举报