Common Substrings POJ - 3415 (后缀数组 + 单调栈)

A substring of a string T is defined as:

\(T(i, k) = T_i T_{i+1} \ldots T_{i+k-1},\quad 1 \le i \le i+k-1 \le |T|.\)
Given two strings A, B and one integer K, we define S, a set of triples (i, j, k):

S = {(i, j, k) | k≥K, A(i, k)=B(j, k)}.
You are to give the value of |S| for specific A, B and K.

Input
The input file contains several blocks of data. For each block, the first line contains one integer K, followed by two lines containing strings A and B, respectively. The input file is ended by K=0.

1 ≤ |A|, |B| ≤ \(10^5\)
1 ≤ K ≤ min{|A|, |B|}
Characters of A and B are all Latin letters.

Output
For each case, output an integer |S|.

Sample Input
2
aababaa
abaabaa
1
xx
xx
0
Sample Output
22
5

题意:求串1和串2中所有长度 \(\geq k\) 的相同子串个数(起始位置不同的子串分别计数)。
思路:首先要把两个字符串连接起来,并用一个没出现的字符分隔, 假设 Sa[j] 属于串1,Sa[i ~ (j - 1)] 属于串2, 且 Height[i + 1] 到 Height[j] 单调递减且 Height[j] \(\geq k\)
根据后缀数组的性质, Sa[i ~ (j - 1)] 中每个后缀与 Sa[j] 的最长公共前缀都等于 Height[j], 所以 Sa[i ~ (j - 1)] 和 Sa[j] 的贡献为 (j - i) * (Height[j] - k + 1), 所以我们可以维护一个单调递增的栈, 将所有连续递减的Height[i]压缩成一个块,
记录这个块的最小 Height 和属于串2的后缀的个数,当 Sa[i] 属于串1时,加上前面所有块的贡献,为了快速计算贡献和,要维护一个前缀和。
所以我们for两次, 一次求串1的后缀和该后缀前面所有串2后缀的贡献,一次求串2的后缀和该后缀前面所有串1后缀的贡献。

#include <cstdio>
#include <algorithm>
#include <queue>
#include <stack>
#include <string>
#include <math.h>
#include <string.h>
#include <map>
#include <iostream>
using namespace std;
// Capacity of every fixed-size global array declared in this file.
const int maxn = 5e5 + 50;
// Not referenced by any code visible in this section.
const int mod = 20090717;
// Large sentinel value (1e9).
int INF = 1e9;
typedef long long LL;
typedef pair<int, int> pii;
#define fi first
#define se second
// Suffix-array state, 1-indexed over the text a[1..n]:
//   Sa[r]     - start position of the suffix that has rank r
//   Height[r] - for r >= 2, common-prefix length of suffixes Sa[r - 1] and Sa[r]
//   Tax       - bucket counters used by Rsort
//   Rank[i]   - current sort key (finally the rank) of the suffix starting at i
//   tp        - suffixes ordered by the secondary key, consumed by Rsort
//   a         - integer copy of the characters in str
//   n         - length of the text;  m - largest key value Rsort has to bucket
//   minLen    - the K read by main (minimum common-substring length)
int Sa[maxn], Height[maxn], Tax[maxn], Rank[maxn], tp[maxn], a[maxn], n, m, minLen;
// Raw input buffer; main stores the text starting at index 1.
char str[maxn];

// Stable counting sort: orders the suffixes listed in tp[1..n] by their
// current key Rank[.] (keys lie in [0, m]) and writes the result to Sa[1..n].
// tp is consumed back-to-front so that suffixes with equal keys keep the
// relative order they had in tp.
void Rsort(){
    int key = 0;
    while(key <= m) Tax[key++] = 0;

    // Histogram of keys.
    for(int pos = 1; pos <= n; ++pos){
        ++Tax[Rank[tp[pos]]];
    }

    // Prefix sums: Tax[key] becomes the last Sa slot owned by that key.
    for(key = 1; key <= m; ++key){
        Tax[key] += Tax[key - 1];
    }

    // Place each suffix into the last free slot of its bucket.
    for(int pos = n; pos > 0; --pos){
        const int suffix = tp[pos];
        int &slot = Tax[Rank[suffix]];
        Sa[slot] = suffix;
        --slot;
    }
}

// Returns 1 when suffixes x and y have the same (first key, second key) pair
// in key array f, where the second key lives w positions further on; returns 0
// otherwise.
int cmp(int *f, int x, int y, int w){
    // Never read past position n: with multiple test cases the arrays still
    // hold data from earlier inputs there, so this guard is mandatory.
    const bool bothInside = (x + w <= n) && (y + w <= n);
    if(!bothInside) return 0;
    if(f[x] != f[y]) return 0;
    return f[x + w] == f[y + w] ? 1 : 0;
}

// Builds the suffix array of a[1..n] by prefix doubling and then the Height
// array with Kasai's linear scan.
//
// Inputs : a[1..n] (every value must be in [1, 200]), n.
// Outputs: Sa[1..n], Rank[1..n], Height[1..n]; Height[1] is always 0.
// Scratch: tp, Tax, m.
void Suffix(){
    for(int i = 1; i <= n; i++) Rank[i] = a[i], tp[i] = i;
    m = 200, Rsort();
    for(int w = 1, p = 1, i; p < n; w += w, m = p){
        // Secondary-key order: suffixes shorter than w+1 have an empty second
        // half and come first, the rest follow in the order of Sa shifted by w.
        for(p = 0, i = n - w + 1; i <= n; i++) tp[++p] = i;
        for(i = 1; i <= n; i++) if(Sa[i] > w) tp[++p] = Sa[i] - w;
        Rsort();
        // Only indices 1..n are live. Swapping the whole maxn-sized arrays
        // would cost O(maxn) per round regardless of the input size.
        std::swap_ranges(Rank + 1, Rank + n + 1, tp + 1);
        Rank[Sa[1]] = p = 1;
        for(i = 2; i <= n; i++) Rank[Sa[i]] = cmp(tp, Sa[i], Sa[i - 1], w) ? p : ++p;
    }

    int k = 0;
    for(int i = 1; i <= n; i++){
        if(Rank[i] == 1){
            // The smallest suffix has no predecessor in Sa; do not read Sa[0].
            Height[1] = k = 0;
            continue;
        }
        if(k) --k;
        const int j = Sa[Rank[i] - 1];
        // Bound both cursors by n: a[] is not cleared between test cases, so
        // a[n + 1..] may still hold characters of an earlier, longer input
        // that would otherwise extend the match.
        while(i + k <= n && j + k <= n && a[i + k] == a[j + k]) ++k;
        Height[Rank[i]] = k;
    }
}

// Number of sparse-table levels. Level j covers 2^j entries; since
// 2^19 = 524288 > maxn, levels 0..18 are the only ones ever filled, so 20 is
// ample (the former 60 columns only wasted static storage).
const int maxLog = 20;
// dpmi[i][j] = min(Height[i .. i + 2^j - 1]).
int dpmi[maxn][maxLog];

// Builds the sparse table over Height[1..n] for range-minimum queries.
void RMQ(){
    for(int i = 1; i <= n; i++){
        dpmi[i][0] = Height[i];
    }
    for(int j = 1; j < maxLog && (1 << j) <= n; j++){
        const int half = 1 << (j - 1);
        for(int i = 1; i + (1 << j) - 1 <= n; i++){
            dpmi[i][j] = std::min(dpmi[i][j - 1], dpmi[i + half][j - 1]);
        }
    }
}

int QueryMin(int l, int r){
    int k = log2(r - l + 1);
    return min(dpmi[l][k], dpmi[r - (1 << k) + 1][k]);
} 

int QueryLcp(int i, int j){
    if(i > j) swap(i, j);
    i++;
    return QueryMin(i, j);
}


// Returns the largest rank r in [i, n] such that the suffixes ranked i and r
// share a common prefix of at least minLen characters. Expects 1 <= i <= n.
// Rank i itself always qualifies, so the search starts at i + 1; probing
// mid == i would ask QueryLcp for the empty Height range (i + 1 .. i).
int Find(int i){
    int le = i + 1, ri = n;
    int res = i;
    while(le <= ri){
        const int mid = le + (ri - le) / 2;
        if(QueryLcp(i, mid) >= minLen){
            // LCP with rank i is non-increasing in mid, so keep looking right.
            res = mid;
            le = mid + 1;
        } else {
            ri = mid - 1;
        }
    }
    return res;
}
// One block of the monotonic stack used by main; stk[1..top] is the live stack.
struct qnode
{
    int cnt;  // number of counted suffixes merged into this block
    int h;    // Height value shared by the block
    LL sum;   // running total of (h - minLen + 1) * cnt from the stack bottom up to this block
} stk[maxn];
int main(int arg, char const *argv[])
{
    while(1){
        scanf("%d", &minLen);
        if(!minLen) break;
        scanf("%s", str + 1);
        int len = strlen(str + 1) + 1;
        str[len] = '0';
        scanf("%s", str + len + 1);
        n = strlen(str + 1);
        for(int i = 1; i <= n; i++) a[i] = str[i];
        Suffix();
        LL ans = 0;
        int top = 0;
        for(int i = 2; i <= n; i++){
            int cnt = 0;
            if(Height[i] < minLen) {
                top = 0;
                continue;
            }
            while(top && Height[i] <= stk[top].h) {
                cnt += stk[top--].cnt;
            }
            if(Sa[i - 1] > len) cnt++;
            stk[++top].cnt = cnt;
            stk[top].h = Height[i];
            stk[top].sum = stk[top - 1].sum + (stk[top].h - minLen + 1) * stk[top].cnt;
            if(Sa[i] <= len) ans += stk[top].sum;
        }

        top = 0;
        for(int i = 2; i <= n; i++){
            int cnt = 0;
            if(Height[i] < minLen) { // 这步必须有
                top = 0;
                continue;
            }
            while(top && Height[i] <= stk[top].h) {
                cnt += stk[top--].cnt;
            }
            if(Sa[i - 1] <= len) cnt++;
            stk[++top].cnt = cnt;
            stk[top].h = Height[i];
            stk[top].sum = stk[top - 1].sum + (stk[top].h - minLen + 1) * stk[top].cnt;
            if(Sa[i] > len) ans += stk[top].sum;
        }
        printf("%I64d\n", ans);
    }
    return 0;
}

posted @ 2020-07-11 11:30  从小学  阅读(143)  评论(0编辑  收藏  举报