学习笔记:KMP
因为之前学过 KMP 忘了所以写了学习笔记,防止自己再忘。
引入
KMP是一种字符串匹配算法,可以在将近线性的时间复杂度内进行字符串匹配。
此类问题通常有一个文本串 \(S\) 和一个模式串 \(P\) 构成,说白了就是在 \(S\) 中匹配 \(T\),S.find(T)
。
不难打出一个暴力程序,原理是枚举 \(S\) 中开始时匹配的位置,逐位匹配字符。设 \(|S|=n,|T|=m\),则暴力的时间复杂度为 \(O(nm)\),代码如下。
int Search(string S,string T)
{
int n=S.size(),m=T.size();
for(int i=0;i<n-m+1;i++)
{
bool fl=1;
for(int j=0;j<m;j++)
{
if(S[i+j]!=T[j])
{
fl=0;
break;
}
}
if(fl)
return i;
}
return -1;
}//返回第一次匹配成功在S中的下标
例如 \(S\)=ABABC
,\(P\)=ABC
。
下文中用 Y
代表匹配成功,N
代表匹配失败。
第一次匹配:
ABABC
ABC
-----
YYN
第二次匹配:
ABABC
ABC
-----
N
第三次匹配:
ABABC
ABC
-----
YYY
我们发现我们在中途进行了一些无用的匹配(如第二次匹配),KMP 算法则能将这些无用的匹配省去。
思想
我们定义一个数组 \(nxt\),\(nxt[i]=j\) 表示在模式串 \(P\) 中截取前 \(i-1\) 位字符作为一个字符串。在这个新字符串中前 \(j\) 位与后 \(j\) 位相等。(\(j<i-1\),\(j\) 在所有值中取最大值)。特别的,\(nxt[1]=-1\)。
当 \(P\)=AABAABD
时,\(nxt\) 数组的值如下表。
真前缀的意思是不为本身的最长前缀,真后缀同理。
下标 | 对应字符串 | \(nxt[i]\) | 最长相同真前后缀 |
---|---|---|---|
1 | |
-1 | |
2 | A |
0 | |
3 | AA |
1 | A |
4 | AAB |
0 | |
5 | AABA |
1 | A |
6 | AABAA |
2 | AA |
7 | AABAAB |
3 | AAB |
8 | AABAABD |
0 | |
当文本串 \(S\)=AABAABAABD
时,我们进行第一次匹配:
AABAABAABD
AABAABD
--------------
YYYYYYN
不难发现,我们第二次匹配时,我们可以直接这样匹配:
AABAABAABD
AABAABD
--------------
YYYYYYY
因为在第一次进行匹配时,我们发现文本串中的
AABAABAABD
|||
与模式串中的
AABAABD
|||
是相同的,所以我们可以直接从文本串中的第 \(7\) 位与模式串中的第 \(4\) 位开始匹配。
观察上面可以得出KMP的核心思想。既通过寻找上次已经匹配过的部分中无需再次匹配的部分来节省时间。
用另一种方式说,就是在一个字符串中寻找它的一个子串 \(a\),使得这个字符串满足以下形式:
a...a
或
aa
并且要使 \(a\) 长度尽可能的长。
实现
P3375 【模板】KMP字符串匹配
代码如下
#include<iostream>
#include<cstring>
using namespace std;
#define N 2000010
char s[N],t[N];
int len1,len2,nxt[N];
int main()
{
scanf("%s%s",s,t);
nxt[0]=nxt[1]=0;
len1=strlen(s);
len2=strlen(t);
for(int i=1,j=0;i<len2;i++)
{
while(j&&t[i]!=t[j])
j=nxt[j];
if(t[i]==t[j])
{
j++;
nxt[i+1]=j;
}
}//求模式串的next数组
for(int i=0,j=0;i<len1;i++)
{
while(j&&s[i]!=t[j])
j=nxt[j];
if(s[i]==t[j])
j++;
if(j==len2)
printf("%d\n",i-len2+2);
}//进行匹配
for(int i=1;i<=len2;i++)
printf("%d ",nxt[i]);
printf("\n");
}
1-Index Ver. & fail
for(int i=2,j=0;i<=n;i++)
{
while(j>0&&s[i]!=s[j+1])
j=nxt[j];
if(s[i]==s[j+1])
j++;
nxt[i]=j;
}
if(s[1]!=s[2])
fail[1]=1;
for(int i=2;i<=n;i++)
{
int j=nxt[i];
while(j>0&&s[j+1]==s[i+1])
j=nxt[j];
if(s[i+1]!=s[j+1])
j++;
fail[i]=j;
}