KMP算法

KMP算法

Knuth-Morris-Pratt字符串查找算法(简称为KMP算法)可在一个字符串S内查找一个词W的出现位置。一个词在不匹配时本身就包含足够的信息来确定下一个匹配可能的开始位置,此算法利用这一特性以避免重新检查先前配对的字符。

这个算法由高德纳和沃恩·普拉特在1974年构思,同年詹姆斯·H·莫里斯也独立地设计出该算法,最终三人于1977年联合发表。

前缀函数的定义

给定一个长度为 \(n\) 的字符串 \(S\),其前缀函数被定义为一个长度为 \(n\) 的数组 \(\pi\)

其中 \(\pi[i]\) 的定义是:

  • 如果长度为 \(i + 1\) 的子串 \(S[:i]\) 有一对相等的真前缀与真后缀:\(S[0:k-1]\)\(S[i - k + 1:i]\),那么 \(\pi[i]\) 就是这个相等的真前缀(或者真后缀)的长度,即 \(\pi[i]=k\)
  • 如果不止有一对相等的,那么 \(\pi[i]\) 就是其中最长的那一对的长度;
  • 如果没有相等的,那么 \(\pi[i]=0\)

简单来说 \(\pi[i]\) 就是:子串 \(S[:i]\)最长的相等真前缀与真后缀的长度。

用数学语言描述如下:

\[\pi[i] = \max^i_{k = 0}(k), \ S[0:k - 1] = S[i - k + 1:i] \]

特别地,规定 \(\pi[0]=0\)

举例

例如,对于字符串 \(abcabcd\),它的前缀函数如下:

\(i\) 0 1 2 3 4 5 6
\(S_i\) \(a\) \(b\) \(c\) \(a\) \(b\) \(c\) \(d\)
\(\pi(i)\) 0 0 0 1 2 3 0

其中,因为 \(S[:0]=a\) 没有真前缀和真后缀,所以,特别地规定,\(\pi[0]=0\)

这里,我们直接给出前缀函数的计算方法:

def prefix_function(s: str):
    n = len(s)
    pi = [0] * n
    for i in range(1, n):
        while j > 0 and s[i] != s[j]:
            j = pi[j - 1]

        if s[i] == s[j]:
            j += 1
        pi[i] = j
    return pi

其中,时间复杂度为:\(O(n)\)

前缀函数的应用

KMP算法

给定一个长度为 \(m\) 的文本 \(text\) 和一个长度为 \(n\) 的字符串 \(pattern\),找到并展示 \(s\)\(t\) 中的所有出现(occurrence)。

利用前面的前缀函数计算方法,我们直接给出KMP算法对其应用:

def kmp(text: str, pattern: str):
    m, n = len(text), len(pattern)
    pi = prefix_function(pattern)
    j = 0
    for i in range(m):
        while j > 0 and text[i] != pattern[j]:
            j = pi[j - 1]

        if text[i] == pattern[j]:
            j += 1

        if j == n:
            return i - n + 1
    return -1

其中,时间复杂度为:\(O(n)\)

统计每个前缀的出现次数

问题1:给定一个长度为 \(n\) 的字符串 \(s\),在问题的第一个变种中我们希望统计每个前缀 \(s[0 \dots i]\) 在同一个字符串的出现次数。

问题2:给定一个长度为 \(n\) 的字符串 \(s\),统计每个前缀 \(s[0 \dots i]\) 在另一个给定字符串 \(t\) 中的出现次数。

def count():
    ans = [0] * (n + 1)
    for i in range(0, n):
        ans[pi[i]] += 1
    for i in range(n - 1, 0, -1):
        ans[pi[i - 1]] += ans[i]
    for i in range(0, n + 1):
        ans[i] += 1

字符串压缩

应用

应用1:Leetcode

题目

28. 找出字符串中第一个匹配项的下标

分析

参考前面的KMP算法实现即可。
构造

代码实现

class Solution:
    def strStr(self, haystack: str, needle: str) -> int:
        m, n = len(needle), len(haystack)
        if m == 0:
            return 0

        _next = [0] * m
        j = 0
        for i in range(1, m):
            while j > 0 and needle[i] != needle[j]:
                j = _next[j - 1]

            if needle[i] == needle[j]:
                j += 1
            _next[i] = j

        j = 0
        for i in range(n):
            while j > 0 and haystack[i] != needle[j]:
                j = _next[j - 1]

            if haystack[i] == needle[j]:
                j += 1

            if j == m:
                return i - m + 1
        return -1

参考:

posted @ 2023-03-21 16:02  LARRY1024  阅读(20)  评论(0编辑  收藏  举报