【算法】KMP 与 Z 函数

1. KMP#

1.1 算法简介#

可以做到线性匹配的快速匹配字符串的算法,并可以维护字符串最长公共前后缀,扩展出计算字符串周期。

在 OI 界 KMP 算法是字符串板块中很经典的算法,可以扩展出很多巧妙的解题技巧。

1.2 算法流程#

1.2.1 字符串匹配#

考虑 O(n2) 暴力的匹配,瓶颈在于每次匹配了很多重复却非法的字符导致效率很慢。

然后就是考虑如何优化,无非就是利用已经计算过的信息。

这里补充一个 最长公共前后缀 的概念:对于字符串 a,最长公共前后缀的长度为满足 a[1,i]=a[ni+1,n] 的最大的 i(i<n)。可以发现如果该定义包含本身则没有意义(长度一定为 n)。所以需要保证最长公共前后缀小于本身的长度。

一个匹配下分 s(模式串)和 t (匹配串),串长分别为 nm。记 kmpi(1im) 表示前缀 t[1,i] 的最长公共前后缀。

假设已经计算出了 kmpi。考虑如何匹配字符串。

s: caabaaabacb
t: aaba

此时 kmp:[0,1,0,1]

  • i=1j=1;此时两串匹配的为空串,ii+1
  • i=2j=1;此时两串可匹配成功,匹配了 aabaii+1,j=kmpj=kmp4=1
  • i=3j=2;此时两串匹配了 aii+1,j=kmpj=kmp1=0
  • i=4j=1;此时两串匹配的为空串,ii+1
  • i=5j=1;此时两串匹配了 aaii+1,j=kmpj=kmp2=1
  • i=6j=2;此时两串可匹配成功,匹配了 aabaii+1,j=kmpj=kmp4=1

如此匹配,我们便找到了所有的合法匹配位置,分别为 2,6

考虑分析时间复杂度,可以看成 (i,j) 对齐,按位匹配。所以整个流程 t 串一直在往前移动,时间复杂度 O(n+m)

1.2.2 计算 kmp 数组#

和匹配很像,相当于自己和自己匹配。所到之处的 j 记录为 kmpi 即可。

不过更好的理解是用两个指针 (i,j)。如果 sisj 匹配,则将 j 指针后移,否则跳 kmpj(可以保证 kmpj 已经算出)。然后 kmpi=j

这个过程一定也是 O(n) 的。分析与证明参考匹配。

1.3 算法实现#

计算 kmp 数组:

for (int i = 2, j = 0; i <= n; i++) {
  while(j && t[j + 1] != t[i]) j = nx[j];
  if(t[j + 1] == t[i]) j++;
  nx[i] = j;
}

匹配:

for (int i = 1, j = 0; i <= n; i++) {
  while(j && t[j + 1] != s[i]) j = nx[j];
  if(t[j + 1] == s[i]) j++;
  if(j == m) {
    cout << i - m + 1 << '\n';
  }
}

然后是 P3375 【模板】KMP,就把她俩整合一下即可。

#include<bits/stdc++.h>
#define int long long
#define For(i,l,r) for(int i=l;i<=r;++i)
#define FOR(i,r,l) for(int i=r;i>=l;--i)

using namespace std;

const int N = 1e6 + 10;

int n, m, nx[N];

char s[N], t[N];

signed main() {
  ios::sync_with_stdio(0);
  cin.tie(0), cout.tie(0);
  cin >> (s + 1) >> (t + 1);
  n = strlen(s + 1);
  m = strlen(t + 1);
  for (int i = 2, j = 0; i <= n; i++) {
    while(j && t[j + 1] != t[i]) j = nx[j];
    if(t[j + 1] == t[i]) j++;
    nx[i] = j;
  }
  for (int i = 1, j = 0; i <= n; i++) {
    while(j && t[j + 1] != s[i]) j = nx[j];
    if(t[j + 1] == s[i]) j++;
    if(j == m) {
      cout << i - m + 1 << '\n';
    }
  }
  For(i,1,m) cout << nx[i] << ' ';
  return 0;
}

1.4 扩展#

1.4.1 字符串周期#

UVA1328 Period

给定一个字符串 s,判断其前缀 s[1,i] 是否为周期字符串,并求出其周期长度和循环节数量。

可以发现一些性质:如果对于 s[1,i]imod(ikmpi)=0,则为周期字符串。

证明很简单。

image

按照这样的方式对于每一个小块染上不同的颜色。

image

这里红色段和黄色段相等。

image

这样每一个小段可以传递相等。这样就可以证明出其为周期字符串。

代码挂着
#include<bits/stdc++.h>
#define int long long
#define reg register 
#define For(i,l,r) for(reg int i=l;i<=r;++i)
#define FOR(i,r,l) for(reg int i=r;i>=l;--i)

using namespace std;

const int N = 1e6 + 10;

int n, kmp[N], id;

char s[N];

signed main() {
  ios::sync_with_stdio(0);
  cin.tie(0), cout.tie(0);
  while(cin >> n && n != 0) {
    For(i,1,n) kmp[i] = 0;
    cin >> (s + 1);
    for (int i = 2, j = 0; i <= n; ++i) {
      while(j && s[i] != s[j + 1]) j = kmp[j];
      if(s[i] == s[j + 1]) j++;
      kmp[i] = j;
    }
    ++id;
    cout << "Test case #" << id << '\n';
    For(i,1,n) {
      if(i % (i - kmp[i]) == 0 && kmp[i] != 0) {
        cout << i << ' ' << (i / (i - kmp[i])) << '\n';
      }
    } 
    cout << '\n';
  }
  return 0;
}

2. Z 函数#

2.1 算法流程#

有一个这样的问题:

给定两个字符串 a,b,你要求出两个数组:

  • bz 函数数组 z,即 bb 的每一个后缀的 LCP 长度。
  • ba 的每一个后缀的 LCP 长度数组 p

对于第一个 subtask 所求的数组 z,则为 Z 函数,可以用 扩展 KMP 算法(exKMP) 求得。

2.1.1 暴力求解#

很好想,就拿一对指针 (i,j) 去匹配,匹配记录,失配重置。

时间复杂度 O(n2)

2.1.2 Z-box 引入#

维护一个区间 [l,r]l=i,r=i+zi1,其中 r 为已知最大的合法右端点。

对于 ir,可以分为 zil+1<rl+1zil+1rl+1 两种情况。

  • zil+1<rl+1;此时 zi=zil+1
  • zil+1rl+1;此时 zi=ri+1,然后暴力匹配。

其他情况暴力匹配,然后更新 [l,r] 即可。

时间复杂度 O(n),这样可以做到线性了。

可以发现 exKMPManacher 的算法思想很像。

2.2 算法实现。#

计算 Z 函数:

z[1] = m;
for (reg int i = 2, l, r = 0; i <= m; ++i) {
  if(i <= r) z[i] = min(z[i - l + 1], r - i + 1);
  while(i + z[i] <= m && t[1 + z[i]] == t[i + z[i]]) z[i]++;
  if(i + z[i] - 1 > r) l = i, r = i + z[i] - 1;
}

计算 s,t 后缀的公共最长前缀。

把匹配 s 的数组换成 p,求法和 Z 函数的求法一样。

for (reg int i = 1, l, r = 0; i <= n; ++i) {
  if(i <= r) p[i] = min(z[i - l + 1], r - i + 1);
  while(1 + p[i] <= m && i + p[i] <= n && s[i + p[i]] == t[1 + p[i]]) p[i]++;
  if(i + p[i] - 1 > r) l = i, r = i + p[i] - 1;
}

然后是 【模板】扩展 KMP/exKMP(Z 函数)

#include<bits/stdc++.h>
#define int long long
#define reg register
#define For(i,l,r) for(reg int i=l;i<=r;++i)
#define FOR(i,r,l) for(reg int i=r;i>=l;--i)

using namespace std;

const int N = 2e7 + 10;

int n, m, z[N], p[N], ans1, ans2;

char s[N], t[N];

signed main() {
  ios::sync_with_stdio(0);
  cin.tie(0), cout.tie(0);
  cin >> (s + 1) >> (t + 1);
  n = strlen(s + 1), m = strlen(t + 1);
  z[1] = m;
  for (reg int i = 2, l, r = 0; i <= m; ++i) {
    if(i <= r) z[i] = min(z[i - l + 1], r - i + 1);
    while(i + z[i] <= m && t[1 + z[i]] == t[i + z[i]]) z[i]++;
    if(i + z[i] - 1 > r) l = i, r = i + z[i] - 1;
  }
  for (reg int i = 1, l, r = 0; i <= n; ++i) {
    if(i <= r) p[i] = min(z[i - l + 1], r - i + 1);
    while(1 + p[i] <= m && i + p[i] <= n && s[i + p[i]] == t[1 + p[i]]) p[i]++;
    if(i + p[i] - 1 > r) l = i, r = i + p[i] - 1;
  }
  For(i,1,m) ans1 = (ans1 ^ (i * (z[i] + 1)));
  For(i,1,n) ans2 = (ans2 ^ (i * (p[i] + 1)));
  cout << ans1 << '\n' << ans2 << '\n';
  return 0;
}

作者:Daniel-yao

出处:https://www.cnblogs.com/Daniel-yao/p/18554369

版权:本作品采用「署名-非商业性使用-相同方式共享 4.0 国际」许可协议进行许可。

posted @   Daniel_yzy  阅读(72)  评论(0编辑  收藏  举报
相关博文:
阅读排行:
· 单线程的Redis速度为什么快?
· 展开说说关于C#中ORM框架的用法!
· Pantheons:用 TypeScript 打造主流大模型对话的一站式集成库
· SQL Server 2025 AI相关能力初探
· 为什么 退出登录 或 修改密码 无法使 token 失效
more_horiz
keyboard_arrow_up dark_mode palette
选择主题
menu
点击右上角即可分享
微信分享提示