【模板】扩展 kmp (exkmp) / Z 函数

posted on 2022-08-08 23:29:53 | under 模板 | source

求出一个字符串 \(s\) 的每个后缀与原串的 LCP。

首先由显然的 SAM 做法。考虑线性。

考虑维护区间 \([l,r]\) 表示 \([l,r]=[1,r-l+1]\) 是最右的匹配段。考虑新的 \(i\),如果满足 \(l\leq i\leq r\),则 \(i\) 可以直接取 \(i-l+1\) 的答案继续扩展,否则继续扩展。最后更新区间。

点击查看代码

#include <algorithm>
#include <cassert>
#include <cstdio>
#include <cstring>
#include <vector>
using namespace std;
#ifdef LOCAL
#define debug(...) fprintf(stderr, ##__VA_ARGS__)
#else
#define debug(...) void(0)
#endif
typedef long long LL;
int n, m, z[1 << 26];
char a[1 << 26];
void exkmp(int len) {
  z[1] = len;
  debug("a = %s\n", a + 1);
  for (int i = 2, l = 0, r = 0; i <= len; i++) {
    if (i <= r) z[i] = min(z[i - l + 1], r - i + 1);
    while (i + z[i] <= len && a[1 + z[i]] == a[i + z[i]]) ++z[i];
    if (i + z[i] - 1 > r) r = i + z[l = i] - 1;
    debug("z[%d] = %d\n", i, z[i]);
  }
}
char buf[1 << 26];
int main() {
  scanf("%s%s", buf + 1, a + 1);
  n = strlen(a + 1), m = strlen(buf + 1);
  a[n + 1] = '?';
  for (int i = 1; i <= m; i++) a[n + i + 1] = buf[i];
  exkmp(n + m + 1);
  LL sum1 = n + 1, sum2 = 0;
  for (int i = 2; i <= n; i++) sum1 ^= 1ll * i * (z[i] + 1);
  for (int i = 1; i <= m; i++) sum2 ^= 1ll * i * (z[i + n + 1] + 1);
  printf("%lld\n", sum1);
  printf("%lld\n", sum2);
  return 0;
}

点击查看代码

#include <algorithm>
#include <cstdio>
#include <cstring>
using namespace std;
typedef long long LL;
int n, m, z[20000010], p[20000010];
char a[20000010], b[20000010];
void getz(char *a, int n) {
  z[1] = n;
  for (int i = 2, l = 0, r = 0; i <= n; i++) {
    if (i <= r) z[i] = min(z[i - l + 1], r - i + 1);
    while (i + z[i] <= n && a[i + z[i]] == a[z[i] + 1]) z[i]++;
    if (i + z[i] - 1 > r) r = i + z[l = i] - 1;
  }
}
void match(char *a, int n, char *b, int m) {
  b[m + 1] = '?';
  for (int i = 1, l = 0, r = 0; i <= n; i++) {
    if (i <= r) p[i] = min(z[i - l + 1], r - i + 1);
    while (i + p[i] <= n && a[i + p[i]] == b[p[i] + 1]) p[i]++;
    if (i + p[i] - 1 > r) r = i + p[l = i] - 1;
  }
}
int main() {
  //	#ifdef LOCAL
  //	 	freopen("input.in","r",stdin);
  //	#endif
  scanf("%s%s", a + 1, b + 1), n = strlen(a + 1), m = strlen(b + 1);
  getz(b, m), match(a, n, b, m);
  LL ans1 = 0, ans2 = 0;
  for (int i = 1; i <= n; i++) ans1 ^= 1ll * i * (p[i] + 1);
  for (int i = 1; i <= m; i++) ans2 ^= 1ll * i * (z[i] + 1);
  printf("%lld\n%lld\n", ans2, ans1);
  return 0;
}

posted @ 2023-10-20 10:08  caijianhong  阅读(20)  评论(0编辑  收藏  举报