ExKMP

ExKMP

给定两个字符串 \(S\)\(T\) (长度分别为 \(n\)\(m\) ) , 下标从 \(0\) 开始 , 定义 \(extend[i]\) 等于 \(S[i]...S[n - 1]\)\(T\)\(LIS\) 的长度, 求出所有 \(extend[i]\) .

\(ExKMP\) 解决的就是这种一个串的所有后缀与模式串的最长公共前缀的问题, 算法与 \(KMP\) 的思想一样, 都是通过失配指针来线性解决问题.

定义一个数组 \(nxt\) , \(nxt[i]\) 表示 \(T[i]...T[m - 1]\)\(T\) 的最长公共前缀的长度, 我们可以通过将 \(T\) 与自己匹配得到.

\(p\) 表示以 \(a\) 为起始位置匹配的第一个失配位置, 也就是 \(S[a...p - 1] == T[0...p - a - 1]\) .

我们从头开始匹配, 直到第一个不匹配的字符, 这就是 \(p\) , 然后我们处理 \(extend[i], i \in (a, p)\) .

首先, 我们知道, \(S[i...p - 1] == T[i - a...p - a - 1], T[i - a...i - a + nxt[i - a]] == T[0...nxt[i - a]]\) , 定义 \(len = \min(p - i - 1, nxt[i - a])\) , 因为 \(T[i - a...i - a + len] == T[i - a...i - a + len]\) , 所以 \(S[i...i + len] == T[i - a...i - a + len]\) . 这样我们就只需要比较一下 \(i + nxt[i]\)\(p - 1\) 的关系, 也就是求 \(len\) 就行了.

接下来我们分类讨论:

  1. \(i + nxt[i - a] < p\) , 说明 \(S[i + nxt[i - a] + 1] \ne T[nxt[i - a] + 1]\) , 直接 \(extend[i] = nxt[i - a]\) , 因为我们没有往后比较的必要了, 后面第一位就不同了.
  2. \(i + nxt[i - a] \ge p\) , 说明 \(S[p...i + nxt[i - a]] \ne T[p - a...i - a + nxt[i - a]]\) , 但是 \(S[p...i + nxt[i - a]]\) 有可能与 \(T[p - i...nxt[i - a]]\) 相等, 这样我们就直接从 \(S[p]\) 开始与 \(T[p - i]\) 比较, 一边比较一边跳 \(p\)\(a\), 因为我们的 \(p\) 是以 \(a\) 为起始位置的第一个失配位置, 所以 \(a = i\) , \(p\) 就跟着匹配的位置跳就好了.

\(nxt\) 数组的时候注意一下, 我们下标是从 \(0\) 开始的, 这里要从 \(1\) 开始匹配, 因为自己和自己从 \(0\) 匹配肯定都是相同的, 就没有意义了. 然后把 \(extend\) 换成 \(nxt\) , 做一遍完全相同的匹配就好了.

\(code:\)

#include <cstdio>
#include <cstring>
const int N = 2e7 + 5;
typedef long long ll;
int nxt[N], extend[N];
ll ans1, ans2;
char s[N], t[N];
void get_nxt() { //处理 nxt 数组
    int len = strlen(t);
    nxt[0] = len;
    for (int a = 1; a < len; a++) {
        int p = a; //从 a 开始匹配
        while (t[p] == t[p - a] && p < len) p++; //p 为第一个失配位置
        nxt[a] = p - a;
        for (int i = a + 1; i < p; i++) { //处理 (a, p) 的 nxt
            if (i + nxt[i - a] >= p) {
                while (t[p] == t[p - i] && p < len) p++; //继续向后匹配, 顺便处理 p
                nxt[i] = p - i;
                a = i; //a 与 p 必须一起跳, 否则就不符合我们的定义了
            }
            else nxt[i] = nxt[i - a];
        }
    }
    for (int i = 0; i < len; i++) ans1 ^= 1ll * (i + 1) * (nxt[i] + 1);
}
void get_extend() { //处理 extend 数组, 具体内容与处理 nxt 的函数几乎相同
    int lens = strlen(s), lent = strlen(t);
    for (int a = 0; a < lens; a++) {
        int p = a;
        while (s[p] == t[p - a] && p < lens && p - a < lent ) p++;
        extend[a] = p - a;
        for (int i = a + 1; i < p; i++) {
            if (i + nxt[i - a] >= p) {
                while (s[p] == t[p - i] && p < lens && p - i < lent) p++;
                extend[i] = p - i;
                a = i;
            }
            else extend[i] = nxt[i - a];
        }
    }
    for (int i = 0; i < lens; i++) ans2 ^= 1ll * (i + 1) * (extend[i] + 1);
}
int main() {
    scanf("%s%s", s, t);
    get_nxt(); get_extend();
    printf("%lld\n%lld", ans1, ans2);
    return 0;
}

理解起来不算很难, 借助画图可以非常容易的理解, 代码实现有一点需要注意一下, 就是 \(a\)\(p\) 一定要一起跳, 这样才符合我们的定义, 否则后面都是错的, 这里卡了我整整一天.

posted @ 2021-08-21 16:36  sshadows  阅读(31)  评论(0编辑  收藏  举报