KMP算法

一. 概述

要解决的问题:字符串匹配问题。

  • 目标串target:"aabaabaafa"
  • 模式串pattern:"aabaaf"

传统算法:

双层for循环遍历目标串target和模式串pattern,判断pattern在target第一次出现的位置。

时间复杂度为:\(O(pattern.size()*target.size())\)=\(O(m*n)\)

KMP算法核心思路:

在对目标进行匹配时,如果\(pattern[j] != target[i]\)

  • 之前的做法:\(j = 0; i=i+1\)
  • 改进后:i不变,我们寻找j前面字符串的最长公共前后缀(假设长度为m),那么令j=m,随后再比较pattern[j]和target[i]。

kmp的核心在于避免对模式串的相同前后缀进行多次比较。

最长公共前后缀:

指的是字符串中相等且最长的前后缀串的长度。例如:

  • a:0。
  • aa:1
  • aab:0
  • aaba:1
  • aabaa:2
  • aabaaf:0

假如匹配到aabaaf的 f 时发现不匹配,那么就看 f 前面的字符串,aabaa的最长公共前后缀为2,因此下次从下标2,也就是 b 开始比较。

二. 求next数组

通过上面的描述,我们发现首先要计算出pattern中每个字符前面的字符串的长公共前后缀,也就是俗称的next数组。

Eg:以上述的pattern为例,它的next数组如下:

a a b a a f
0 1 0 1 2 0

核心思路:

求字符串的最长公共前后缀,实际上也是一个字符串匹配的问题,只不过这一次不是匹配两个字符串,而是匹配字符串的前缀和后缀,难点在于前缀和后缀长度是动态变化的,我们用 j 指向前缀字符串的末尾,用 i 指向后缀字符串的末尾(同时也通过 i 来遍历pattern串):

  1. 如果\([i]==[j]\),那么\(next[i]=j\)
  2. 如果\([i]!=[j]\),那么表示匹配失败,这时采用kmp算法,查看 j 前面的字符串的最长公共前后缀,也就是 next[j-1]的值。令\(j=next[j-1]\),然后再次比较 [i] 和 [j],递归1、2步。

可以看出,这实际上是一个递归的过程,我们为了使用kmp算法,需要求pattern的next数组,而在求pattern的next数组过程中,我们又对字符串[0,...,j]和[i-j,...,i]使用了kmp算法。

代码:

// 求next数组
void get_next(vector<int>& next, string& s){
    // 1. 初始化
    int j = 0;
    next[0] = j;
    for(int i = 1; i < s.size(); i++){
        // 2. 判断前后缀不相等的情况
        // 注意:此处是while而不是if。要保证j>0,避免越界操作
        while(j > 0 && s[i] != s[j]){
            // 递归使用kmp算法
            j = next[j - 1];        
        }
        // 3. 判断前后缀相等的情况
        if(s[i] == s[j]){
            j++; 
        }
        // 4. 赋值 
        next[i] = j;
    }
}

三.代码实现

当我们求出pattern的next数组后,kmp的核心工作实际上已经完成。剩下的代码就是简单的遍历判断了:

// 求next数组
void get_next(vector<int>& next, string& s){
    // 1. 初始化
    int j = 0;
    next[0] = j;
    for(int i = 1; i < s.size(); i++){
        // 2. 判断前后缀不相等的情况
        // 注意:此处是while而不是if。要保证j>0,避免越界操作
        while(j > 0 && s[i] != s[j]){
            // 递归使用kmp算法
            j = next[j - 1];        
        }
        // 3. 判断前后缀相等的情况
        if(s[i] == s[j]){
            j++; 
        }
        // 4. 赋值 
        next[i] = j;
    }
}
int my_strstr(string& target, string& pattern){
    if(pattern.size() == 0){
        return 0;
    }
    // 求模式串的next数组
    vector<int> next(pattern.size(), 0);
    get_next(next, pattern);
    int j = 0;      // j指向模式串
    // 遍历目标串
    for(int i = 0; i < target.size(); i++){
        // 注意这里还是while
        while(j > 0 && target[i] != pattern[j]){
            j = next[j - 1];
        }    
        // 如果等于
        if(target[i] == pattern[j]){
            j++;
        }
        if(j == pattern.size()){
            return i - j + 1;
        }
    }
    return -1;
}
posted @ 2024-03-30 21:38  BinaryPrinter  阅读(11)  评论(0编辑  收藏  举报