kmp算法详解
引入
kmp算法要解决的就是用on的时间复杂度模式串p在文本串T中的匹配问题
过程
字符串下标从1开始
对于文本串T(上)和模式串p(下)T.size()=n , p.size()=m
设T[i]和p[j]为正在接受比对的一对字符
-
如果j<m-1&&T[i+1]==p[j+1],那么i++,j++。
-
如果T[i+1]!=p[j+1],那么我们就找到一个最大的k,使得p[1]...p[k]与T[i-k+1]...T[i]匹配,让j退回到k继续比较,j回退的过程中i是不动的,所以只是遍历了一遍文本串,复杂度on
-
如果回退后还不想等,那么继续后退,直到T[i+1]=p[j+1]或者j==0为止
这样我们要解决的问题就只剩下了如何快速地预处理出k。
next数组
关于快速求解k其实是与文本串T没有关系的。
k的作用是为了让p的前缀与T中以T[i]结尾字符串的后缀匹配上。
但是当我们要让j回退到k的时候,T[i-j+1]...T[i]是等于p[1]...p[j]的,
也就是这个时候p中以p[j]结尾字符串的后缀是等于T中以T[i]结尾字符串的后缀的。
所以问题就变成了让p的前缀与p中以p[j]结尾字符串的后缀匹配上,
这样一看,整个求解模式串p中每个j对应的k的过程中都不需要文本串T的参与
所以我们求解出的k只需要满足p[1]...p[k]与p[j-k+1]...p[j]匹配,我们来用next数组来维护每个j对应的k
遍历每一个j,当next[j]确定后,我们这样计算next[j+1]:
看p[ next[j] + 1]是否等于p[j+1],如果等于,那么next[j+1] = next[j]+1
否则的话设next[j]=k,让k = next[k]这样回退直到匹配或者k=0。当退无可退(k=0),比较p[j+1]与p[1]相等的话next[j+1]=1,否则next[j+1]=0
可以根据下图模拟一下过程
kmp的板子如下
//初始化ne数组
for(int i=2,j=0;i<=m;i++) //i为将要匹配的后缀的最后一个元素的下标,j为已匹配的前缀的最后一个下标
{
while(j && p[i]!=p[j+1]) j = ne[j];
if(p[i] == p[j+1]) j++;
ne[i] = j;
}
//kmp匹配
for(int i=1,j=0;i<=n;i++)
{
while(j && T[i]!=p[j+1]) j = ne[j];
if(T[i] == p[j+1]) j++;
if(j==m) cout<<i-m+1<<endl;
}
例题
1.P3375 【模板】KMP 字符串匹配 https://www.luogu.com.cn/problem/P3375
# include<bits/stdc++.h>
using namespace std;
const int N = 1e6+10;
char T[N],p[N]; //T为文本串数组,p为模式串数组
int ne[N]; //kmp的next数组,next可能会和一些保留字冲突所以改成ne
int main()
{
cin>>T+1>>p+1;
int n = strlen(T+1),m = strlen(p+1);
//初始化ne数组
for(int i=2,j=0;i<=m;i++) //i为将要匹配的后缀的最后一个元素的下标,j为已匹配的前缀的最后一个下标
{
while(j && p[i]!=p[j+1]) j = ne[j];
if(p[i] == p[j+1]) j++;
ne[i] = j;
}
//kmp匹配
for(int i=1,j=0;i<=n;i++)
{
while(j && T[i]!=p[j+1]) j = ne[j];
if(T[i] == p[j+1]) j++;
if(j==m) cout<<i-m+1<<endl;
}
for(int i=1;i<=m;i++) cout<<ne[i]<<" ";
return 0;
}
2. 字符串匹配例题
http://oj.daimayuan.top/course/22/problem/908
给你两个字符串 a,b,字符串均由小写字母组成,现在问你 b 在 a 中出现了几次。
输入有多组数据,第一行为数据组数 T,每组数据包含两行输入,第一行为字符串 a,第二行为字符串 b。
对于每组输入需要输出两行,其中第一行为出现次数,第二行为每次出现时第一个字符在 a 中的下标(字符串首位的下标为 1)。如果找不到,输出两行 −1。
输入格式
第一行一个整数 T。
接下来 2T 行,每行一个字符串。
输出格式
输出 2T 行,每行若干个整数,详见题面。
样例输入
3
abababa
aba
aaaaaa
bb
abcabcab
abcab
样例输出
3
1 3 5
-1
-1
2
1 4
数据规模
对于所有数据,保证 1≤T≤10,1≤|a|,|b|≤105,字符串均由小写字母构成。
# include<bits/stdc++.h>
using namespace std;
const int N = 1e6+10;
char T[N],p[N];
int ne[N];
int main()
{
int t;cin>>t;
while(t--)
{
vector<int> ans;
cin>>T+1>>p+1;
int n = strlen(T+1),m = strlen(p+1);
for(int i=2,j=0;i<=m;i++)
{
while(j && p[j+1]!=p[i]) j = ne[j];
if(p[j+1] == p[i]) j++;
ne[i] = j;
}
for(int i=1,j=0;i<=n;i++)
{
while(j && T[i]!=p[j+1]) j = ne[j];
if(T[i] == p[j+1]) j++;
if(j == m) ans.push_back(i-m+1);
}
if(ans.size())
{
cout<<ans.size()<<endl;
for(auto t:ans) cout<<t<<" ";
cout<<endl;
}
else cout<<"-1"<<endl<<"-1"<<endl;
}
}
3. 最小循环覆盖
http://oj.daimayuan.top/course/22/problem/909
给你一个字符串 a,你需要求出这个字符串的最小循环覆盖的长度。
b 是 a 的最小循环覆盖,当且仅当 a 是通过 b 复制多次并连接后得到的字符串的前缀,且 b 是满足条件的字符串中长度最小的。
输入一个字符串 a,输出一个数表示最小循环覆盖 b 的长度。
输入格式
一行一个字符串 a。
输出格式
输出一行一个数表示答案。
样例输入1
bcabcabcabcab
样例输出1
3
样例输入2
aaaaaaa
样例输出2
1
数据规模
对于所有数据,保证 1≤|a|≤105,字符串均由小写字母构成。
先说结论:
最小覆盖长度 = n-next[n]
证明如下:
字符串a最大匹配(长度为next[n])的前缀和后缀有两种情况
- 前缀和后缀没有公共部分,那么只需要将a-后缀这部分才复制一部分然后删减到a就能得到a,最小覆盖长度也就是n-后缀.size(),即n-next[n]
- 前缀和后缀没有公共部分,那么可以将前缀的一部分与后缀的一部分匹配,然后再将前缀之后的一部分再与后缀之后的一部分匹配,如此往复,那么最后前缀会剩下一部分,也就是n-next[n]就是最小循环覆盖,只不过字母的顺序可能不对。可以拿着bcabcab手玩一下这个结论
# include<bits/stdc++.h>
using namespace std;
const int N = 1e6+10;
int ne[N]; //kmp的next数组,next可能会和一些保留字冲突所以改成ne
char p[N];
int main()
{
cin>>p+1;
int m = strlen(p+1);
//初始化ne数组
for(int i=2,j=0;i<=m;i++) //i为将要匹配的后缀的最后一个元素的下标,j为已匹配的前缀的最后一个下标
{
while(j && p[i]!=p[j+1]) j = ne[j];
if(p[i] == p[j+1]) j++;
ne[i] = j;
}
cout<<m-ne[m];
return 0;
}
4. [UVA 12467] Secret word
http://oj.daimayuan.top/course/22/problem/910
给你一个字符串 s,你需要找出 s 中最长的 secret word,一个字符串 p 是 secret word 需要满足:
p 是 s 的子串(p 可以与 s 相等);
将 p 翻转后是 s 的前缀。
输入一行字符串 s,输出一行字符串为你求得的 p。
输入格式
一行一个字符串 s。
输出格式
输出一行一个字符串表示答案。
样例输入
listentothesilence
样例输出
sil
数据规模
对于所有数据,保证 1≤|s|≤105,字符串均由小写字母构成。
做法:我们可以将字符串s反转为s',将s变为s+'#'+s',然后对着s求一遍kmp。对于s'所在的i中最大的next就是目标字符串的长度。
加入#的目的是防止到了s'这里它的后缀与前缀会有重合的部分。
# include<bits/stdc++.h>
using namespace std;
const int N = 2e5+10;
char s[N];
int ne[N];
int main()
{
string s1;cin>>s1;
string s2 = s1;
reverse(s2.begin(),s2.end());
s1+='#'+s2;
int len = 0;
for(auto i:s1) s[++len] = i;
for(int i=2,j=0;i<=len;i++)
{
while(j && s[i]!=s[j+1]) j = ne[j];
if(s[i] == s[j+1]) j++;
ne[i] = j;
}
int ans = 0;
for(int i=(len-1)/2+2;i<=len;i++) ans = max(ans,ne[i]);
for(int i=ans;i;i--) cout<<s[i];
return 0;
}