「字符串算法」第3章 KMP 算法课堂过关
「字符串算法」第3章 KMP 算法课堂过关
关于KMP
upd on 2021/4/1 优化一些细节
声明:本文的字符串下标均从1开始,对于某个字符串a,a.substr(i,j)表示a从第i位开始,长度为j的字串
模板题
KMP算法的大致原理
个人认为其他博客已经讲得很好,这里简单讲,把重点放在next数组上
先推几篇博客:
首先,我们把模板题中的\(s_1\)串称为文本串,重命名为\(s\),\(s_2\)称为模式串,重命名为\(t\)(本文中不区分s与t的大小写)
设\(n\)为\(s\)的长度,\(m\)为\(t\)的长度(会在代码片中出现)
看图,在第一轮匹配中,匹配到了一个不相等的位置,如果用暴力,那就是从头再匹配,但是可以看到\(t\)串中有一段重复的“ABC”,无需重复匹配,所以第二轮直接跳到如图所示的位置比较两个蓝色的部分
这就是KMP算法的大致思路
next数组
定义
看了KMP的大致原理,相信大家都产生了疑问:我怎么知道要让T串跳到哪个位置呢?这就要用到next数组了,这是KMP的核心,也是难点
先不用管怎么求next数组,看定义(我自己写的):
令\(j=next_i\),则有\(j<i\)且\(t.substr(1,j)==t.substr(i-j+1,i)\),且对于任意\(k(j<k<i)\),\(t.substr(1,k)≠t.substr(i-k+1,k)\)
也就是说,
next[i]
表示“T中以i结尾的非前缀字串”与“T的前缀”能匹配的最长长度,当不存在这样的j时,next[i]=0
举个例子:
若T="ABCDABCE",则对应的
next={0 0 0 0 1 2 3 0}
应用
根据next数组的定义,next中存储的是长度,但是由于它是T的某个前缀字串的长度,我们也可以将next当做下标使用(一定要弄清楚,不然后面很蒙)
仍然用上面的图片
真懒呐设S的指针为i,T的指针为j,表示当前完成匹配的位置(也就是说S[i]和T[j]是相等的)
第一轮匹配中,当\(j==7\)时,我们发现\(t\)的下一位和\(s\)的下一位不等,但是\(t\)的第57位和13位是一样的,即
next[7]=3
,所以我们需要将\(t\)的指针(j)跳到第3位,也就是j=next[j]
,这里有一些细节不是很好理解,KMP在实现时是很巧妙的,我们放到整段代码理解:while(j != 0 && s[i] != t[j+1]) j = next[j]; if(s[i] == t[j+1]) j++; if(j == m){//j==m标志着已经全部完成匹配 printf("%d\n",i - m + 1); j = next[j]; }
求法
这里是整个KMP最难理解的部分,所以放到最后
先贴出代码
next[1] = 0;//初始化 for(int i = 2 , j = 0 ; i <= m ; i++){ while(j != 0 && t[j+1] != t[i]) j=next[j];//全算法最confusing的语句 if(t[j+1] == t[i]) j++; next[i] = j; }
考虑暴力枚举:最外层循环枚举每一位\(i\),第二层枚举
next[i]
,里层判断第二层枚举的是否合法显然,时间复杂度是在\(O(n^2)~O(n^3)\),还不如\(O(n\cdot m)\)的暴力匹配
优化求法:
先提前声明:求
next[i]
是要用到next[1~i-1]
的,所以我们要从前向后顺序枚举i定义“候选项”的概念(可能跟《算法竞赛……》的不大一样):如果j满足
t.substr(1,j)==t.substr(i-1-j-1,j)&&j<i-1
则j是next[i]
的一个候选项例子:
绿色表示相等的两个字串,则j是
next[i]
的一个候选项,若标成蓝色的两个字符相等,则候选项j是合法的,next[i]
就是所有合法的\(j\)中的最大值+1很显然,对于
next[i]
而言,next[i-1]
是它的候选项,但是,问题是next[next[i-1]],next[next[next[i-1]]],......
都是候选项,为什么呢?还是看图:
假设
next[13]=5
,根据\(next\)的定义,标绿色部分是相等的,再细化一下绿色部分中相等的部分:假设next[5]=2
,同理,第二行(不计最上面的下标行)的黄色部分相等,又因为绿色部分相等,我们可以得到第三行的黄色部分都是相等的,再简化为第4行,会发现:这不是和第一行一样了吗(只是长度小了)!以此类推,可以得到
next[i-1],next[next[i-1]],next[next[next[i-1]]],......
都是候选项,且他们的值是从左向右递减的,因此,按照这个顺序找到第一个合法的候选值之后,我们就可以确定next[i]
了重新看一下代码:
next[1] = 0; for(int i = 2 , j = 0 ; i <= m ; i++){ while(j != 0 && t[j+1] != t[i])//找到第一个合法的候选项 j=next[j];//缩小长度 if(t[j+1] == t[i]) j++; next[i] = j; }
发现,每一轮循环没有
j=next[i-1]
的语句。原因很简单:上一轮结束时语句next[i]=j
决定了这一轮刚开始就有j==next[i-1]
,注意这里的前后的\(i\)不一样(都不是同一轮循环了)不要学傻了时间复杂度
上结论:\(O(n+m)\)
以\(next\)数组的求值为例:
next[1] = 0; for(int i = 2 , j = 0 ; i <= m ; i++){ while(j != 0 && t[j+1] != t[i]) j=next[j]; if(t[j+1] == t[i]) j++; next[i] = j; }
最外层显然是\(O(m)\)的,问题是里面
在
while
循环中,\(j\)是递减的,但是又不会变成负数,所以整个过程中,\(j\)的减小幅度不会超过\(j\)增加的幅度,而\(j\)每次才增加1,最多增加\(m\)次,故\(j\)的总变化次数不超过\(2m\),整个时间复杂度近似认为是\(O(m)\)如果还不能理解,就想像一个平面直角坐标系,\(x\)轴为\(i\),\(y\)轴为\(j\),从原点出发,\(i\)每向右一个单位,\(j\)最多向上一个单位,\(j\)也可以往下掉(
while
循环),但不能掉到第四象限,\(j\)向下掉的高度之和就是while
内语句执行的总次数,是绝对不会超过\(m\)的匹配的循环与上述相近,时间为\(O(n+m)\),不再赘述
所以,总的时间复杂度为\(O(n+m)\)
模板题代码
不要问模板题输出的最后一行是什么意思,
我也不知道,反正输出\(next\)数组就对了#include <iostream> #include <cstdio> #include <cstring> #define nn 1000010 using namespace std; int sread(char s[]) { int siz = 1; do s[siz] = getchar(); while(s[siz] < 'A' || s[siz] > 'Z'); while(s[siz] >= 'A' && s[siz] <= 'Z') { ++siz; s[siz] = getchar(); } --siz; return siz; } char s[nn]; char t[nn]; int next[nn]; int n , m; int main() { n = sread(s); m = sread(t); next[1] = 0; for(int i = 2 , j = 0 ; i <= m ; i++){ while(j != 0 && t[j+1] != t[i]) j=next[j]; if(t[j+1] == t[i]) j++; next[i] = j; } for(int i = 1 , j = 0 ; i <=n ; i++){ while(j != 0 && s[i] != t[j+1]) j = next[j]; if(s[i] == t[j+1]) j++; if(j == m){ printf("%d\n",i - m + 1); j = next[j]; } } for(int i = 1 ; i <= m ; i++) printf("%d " , next[i]); return 0; }
A. 【例题1】子串查找
题目
代码
#include <iostream>
#include <cstdio>
#define nn 1000010
using namespace std;
int ssiz , tsiz;
int sread(char *s) {
int siz = 0;
char c = getchar();
while(!((c >= 'a' && c <= 'z') || (c >= 'A' && c <= 'Z')))
c = getchar();
while((c >= 'a' && c <= 'z') || (c >= 'A' && c <= 'Z'))
s[++siz] = c , c = getchar();
return siz;
}
char s[nn] , t[nn];
int nxt[nn];
int main() {
ssiz = sread(s);
tsiz = sread(t);
nxt[1] = 0;
for(int i = 2 ; i <= tsiz ; i++) {
int j = nxt[i - 1];
while(j != 0 && t[j + 1] != t[i])
j = nxt[j];
nxt[i] = j + 1;
}
int ans = 0;
for(int i = 1 , j = 0 ; i <= ssiz ; i++) {
while(j != 0 && s[i] != t[j + 1])
j = nxt[j];
if(s[i] == t[j + 1])
++j;
if(j == tsiz)
j = nxt[j] , ++ans;
}
cout << ans;
return 0;
}
B. 【例题2】重复子串
题目
题目有误
输入若干行,每行有一个字符串,字符串仅含英文字母。特别的,字符串可能为
.
即一个半角句号,此时输入结束。
第五组数据的字符串包含数字字符,有图为证:
思路
设字符串长度为\(siz\)
Hash
不难想也很好写的一种方法
直接枚举最小周期长度\(i\),显然,\(siz\)一定是\(i\)的倍数,所以,这只需要\(O(\sqrt n)\)的时间复杂度
假设我们已经枚举到\(p\)的因数\(x\),就可以直接用\(O(\frac{siz}{x})\)的时间复杂度验证该子字符串是否是周期,代码如下:
inline bool check(int x) {
ul key = hs[x];
for(int i = x + 1 ; i + x - 1 <= siz ; i += x)
if(hs[i + x - 1] - hs[i - 1] * pw[x] != key)//获取字符串s从下标i开始,长度为x的子串的Hash值 , 判断和key是否相等
return false;
return true;
}
KMP
下面讲好些不太好想的KMP做法
先上结论:
命名输入进来的字符串为\(S\),预处理得到\(S\)的\(nxt\)数组
若\(siz\%(siz-nxt_{siz})==0\),则\(siz-nxt_{siz}\)为\(S\)的最小周期,也就是说,此时答案为\(siz / (siz - nxt_{siz})\)
否则,答案为"1"
献上图解:
代码
Hash
#include <iostream>
#include <cstdio>
#include <cstring>
#define nn 1000010
#define ul unsigned long long
using namespace std;
#define value(_) (_ >= 'A' && _ <= 'Z' ? (1 + _ - 'A') : (_ >= 'a' && _ <= 'z' ? (27 + _ - 'a') : (_ - '0' + 53) ))
const ul p = 131;
ul hs[nn];
ul pw[nn];
int siz;
char c[nn];
int sread(char *s) {
int siz = 0;
char c = getchar();
if(c == '.')return -1;
while(!((c >= 'a' && c <= 'z') || (c >= 'A' && c <= 'Z') || (c >= '0' && c <= '9')))
if((c = getchar()) == '.') return -1;
while((c >= 'a' && c <= 'z') || (c >= 'A' && c <= 'Z') || (c >= '0' && c <= '9'))
s[++siz] = c , c = getchar();
return siz;
}
inline bool check(int x) {
ul key = hs[x];
for(int i = x + 1 ; i + x - 1 <= siz ; i += x)
if(hs[i + x - 1] - hs[i - 1] * pw[x] != key)
return false;
return true;
}
int main() {
pw[0] = 1;
for(int i = 1 ; i <= nn - 5 ; i++)
pw[i] = pw[i - 1] * p;
while((siz = sread(c)) != -1) {
for(int i = 1 ; i <= siz ; i++)
hs[i] = hs[i - 1] * p + value(c[i]);
int ans = 0;
for(int i = 1 ; i * i <= siz ; i++) {
if(siz % i == 0)
if(check(i)) {
ans = i;
break;
}
else {
if(check(siz / i))
ans = siz / i;
}
}
printf("%d\n" , siz / ans);
memset(c, 0 , sizeof(c));
memset(hs , 0 , sizeof(hs));
}
return 0;
}
KMP
#include <iostream>
#include <cstdio>
#include <cstring>
#define nn 1000010
using namespace std;
int siz;
int sread(char *s) {
int siz = 0;
char c = getchar();
if(c == '.')return -1;
while(!((c >= 'a' && c <= 'z') || (c >= 'A' && c <= 'Z') || (c >= '0' && c <= '9')))
if((c = getchar()) == '.') return -1;
while((c >= 'a' && c <= 'z') || (c >= 'A' && c <= 'Z') || (c >= '0' && c <= '9'))
s[++siz] = c , c = getchar();
return siz;
}
char s[nn];
int nxt[nn];
int main() {
while(true) {
memset(s , 0 , sizeof(s));
memset(nxt , 0 , sizeof(nxt));
siz = sread(s);
if(siz == -1) break;
nxt[1] = 0;
for(int i = 2 , j = 0 ; i <= siz ; i++) {
while(s[i] != s[j + 1] && j != 0)
j = nxt[j];
if(s[i] == s[j + 1])
++j;
nxt[i] = j;
}
if(siz % (siz - nxt[siz]) == 0)
printf("%d\n" , siz / (siz - nxt[siz]));
else
printf("1\n");
}
return 0;
}
C. 【例题3】周期长度和
题目
思路&代码
以前写过,传送门
题目
这题意不是一般人能读懂的,为了读懂题目,我还特意去翻了题解[手动笑哭]
题目大意:
给定一个字符串s
对于\(s\)的每一个前缀子串\(s1\),规定一个字符串\(Q\),\(Q\)满足:\(Q\)是\(s1\)的前缀子串且\(Q\)不等于\(s1\)且\(s1\)是字符串\(Q+Q\)的前缀.设\(siz\)为所有满足条件的\(Q\)中\(Q\)的最大长度(注意这里仅仅针对\(s1\)而不是\(s\),即一个\(siz\)的值对应一个\(s1\))
求出所有\(siz\)的和
不要被这句话误导了:
求给定字符串所有前缀的最大周期长度之和
正确断句:求给定字符串 所有/前缀的最大周期长度/之和
我就想了半天:既然是"最大周期长度",那不是唯一的吗?为什么还要求和呢?
思路
其实这题要AC并不难(看通过率就知道)
看图
要满足\(Q\)是\(s1\)的前缀,则\(Q\)的\(1\)~\(5\)位和\(s1\)的1~5位是一样的,又因为\(s1\)是\(Q+Q\)的前缀,所以又要满足\(s1\)的6~8位和\(Q+Q\)的6~8位一样,即\(s1\)的6~8位和Q的1~3位相等,回到\(s1\),标蓝色的两个位置相等.
回顾下KMP中\(next\)数组的定义:
next[i]
表示对于某个字符串a,"a中长度为next[i]
的前缀子串"与"a中以第i为结尾,长度为next[i]
的非前缀子串"相等,且next[i]
取最大值是不是悟到了什么,是不是感觉这题和\(next\)数组冥冥之中有某种相似之处?
但是,这仅仅只是开始
按照题目的意思,我们要让\(Q\)的长度最大,也就是图中蓝色部分长度最小,但是\(next\)中存的是蓝色部分的最大值,显然,两者相违背,难道我们要改造\(next\)数组吗?明显不行,若\(next\)存储的改为最小值,则原来求\(next\)的方法行不通.考虑换一种思路(一定要对KMP中\(next\)的求法理解透彻,不然下面看不懂,不行的复习一下),我们知道对于
next[i],next[next[i-1]],next[next[next[i]]]...
都能满足"前缀等于以\(i\)结尾的子串"这个条件,且越往后,值越小,所以,我们的目标就定在上面序列中从后往前第一个不为0的\(next\)值极端条件下,暴力跑可以去到\(O(n^2)\),理论上会超时(我没试过)
两种优化:
- 记忆化,时间效率应该是O(n)这里不详细讲,可以去到洛谷题解查看
- 倍增(我第一时间想到并AC的做法):
我们将j=next[j]
这一语句称作"j跳了一次"(感觉怪怪的),将next拓展为2维,next[i][k]
表示结尾为i,j跳了2^k的前缀字符长度(也就是next[i][0]
等价于原来的next[i]
)
借助倍增LCA的思想(没学没关系,现学现用),这里不做赘述,上代码int tmp = i; for(rr int j = siz[i] ; j >= 0 ; --j)//siz[i]是next[i][j]中第一个为0的小标j,注意倒序枚举 if(next[tmp][j] != 0)//如果不为0则跳 tmp = next[tmp][j];
倍增方法在字符串长度去到\(10^6\)时是非常危险的,带个\(\log\)理论是\(2\cdot 10^7\)左右,常数再大那么一丢丢就TLE了,
还好数据比较水,但是作为倍增和KMP的练习做一下也是不错的最后,记得开longlong(不然我就一次AC了)
完整代码
#include <iostream> #include <cmath> #include <cstdio> #define nn 1000010 #define rr register #define ll long long using namespace std; int next[nn][30] ; int siz[nn]; char s[nn]; int n; int main() { // freopen("P3435_3.in" , "r" , stdin); cin >> n; do s[1] = getchar(); while(s[1] < 'a' || s[1] > 'z'); for(rr int i = 2 ; i <= n ; i++) s[i] = getchar(); next[1][0] = 0; for(rr int i = 2 , j = 0 ; i <= n ; i++) { while(j != 0 && s[i] != s[j + 1]) j = next[j][0]; if(s[j + 1] == s[i]) ++j; next[i][0] = j; } rr int k = log(n) / log(2) + 1; for(rr int j = 1 ; j <= k ; j++) for(rr int i = 1 ; i <= n ; i++) { next[i][j] = next[next[i][j - 1]][j - 1]; if(next[i][j] == 0) siz[i] = j; } ll ans = 0; for(rr int i = 1 ; i <= n ; ++i) { int tmp = i; for(rr int j = siz[i] ; j >= 0 ; --j) if(next[tmp][j] != 0) tmp = next[tmp][j]; if(2 * (i - tmp) >= i && tmp != i) ans += (ll)i - tmp; } cout << ans; return 0; }
D. 【例题4】子串拆分
题目
思路
说明,以下思路时间大致复杂度为\(O(n^2 )\),最坏可以去到\(O(n^3)\),但数据较水可以通过,看了书,上面的解法也是\(O(n^2)\),对于\(1\leq |S|\leq 1.5×10^4\)来说已经是很极限了
其实思路很简单,我们直接枚举子串的左右边界\(L,R\),在右边界扩张的同时把新加入的字符的\(nxt\)求出来.至此,我们得到了子串\(c\),和\(c\)的\(nxt\)数组,时间复杂度为\(O(n^2)\)
那么我们如何判断\(c\)是否符合\(c=A+B+C(k\le len(A),1\le len(B) )\)呢?看代码(其实做了B,C题这里很好理解)
int p = nxt[m];//m为c数组的长度,p即是可能的A的长度
while(p >= k && p > 0) {
if(m - p - p >= 1) {
++ans;
break;//直接退出,优化
}
p = nxt[p];
}
这个判断的复杂度是可以达到\(O(n)\)的,在数据范围下十分危险
下面看下书中是怎么说的:
我还以为书里有严格\(O(n^2)\)的做法
下面\(p_i\)和\(nxt_i\)意义相同
考虑没枚举左端点,假设左端点为\(l\),\(A=S[l,|S|]\),那么对字符串\(A\)跑一次KMP,在匹配的过程中,设匹配到第\(i\)个位置,那么我们就要考虑当前得出的\(j\),显然\(A[1,j]=A[i-j+1,i]\).如果\(i\le 2\cdot j\),那么令\(j=p_j\),此时\(A[i,j]=A[i-j+1,i]\),\(j\)沿指针\(p\)不断回跳,直到\(2\cdot j<i\).然后判断\(j\)是否大于\(k\),如果是,那么累加答案.
因为每次KMP的复杂度是\(O(n)\),所以总时间复杂度为\(O(n^2)\)
核心代码
//每次KMP匹配 inline void solve(char *a) { p[1] = 0; int n = strlen(a + 1); for(int i = 1 , j = 0 ; i < n ; i++) { while(j && a[j + 1] != a[i + 1]) j = p[j]; if(a[j + 1] == a[i + 1]) ++j; p[i + 1] = j; } for(int i = 1 , j = 0 ; i < n ; i++) { while(j && a[j + 1] != a[i + 1]) j = p[j]; if(a[j + 1] == a[i + 1]) j++; while(j * 2 >= i + 1) j = p[j]; if(j >= k) ++ans; } }
//枚举左端点 int len = strlen(str + 1) - (k << 1); for(int i = 0 ; i < len ; i++) solve(str + i);
代码
#include <iostream>
#include <cstdio>
#include <cstring>
#define nn 15000
using namespace std;
int sread(char *s) {
int siz = 0;
char c = getchar();
if(c == '.')return -1;
while(!((c >= 'a' && c <= 'z') || (c >= 'A' && c <= 'Z') || (c >= '0' && c <= '9')))
if((c = getchar()) == '.') return -1;
while((c >= 'a' && c <= 'z') || (c >= 'A' && c <= 'Z') || (c >= '0' && c <= '9'))
s[++siz] = c , c = getchar();
return siz;
}
int n , k , m;
int ans;
int nxt[nn];
char c[nn] , s[nn];
int main() {
n = sread(s);
cin >> k;
for(int L = 1 ; L <= n ; L++) {
memset(nxt , 0 , sizeof(nxt));
memset(c , 0 , sizeof(c));
for(int i = L ; i <= L + k + k ; i++)
c[i - L + 1] = s[i];
m = k + k;
nxt[1] = 0;
for(int i = 2 ; i <= m ; i++) {
int j = nxt[i - 1];
while(c[j + 1] != c[i] && j != 0)
j = nxt[j];
if(c[j + 1] == c[i]) ++j;
nxt[i] = j;
}
for(int R = L + k + k; R <= n ; R++) {
m = R - L + 1;
c[m] = s[R];
int j = nxt[m - 1];
while(c[j + 1] != c[m] && j != 0)
j = nxt[j];
if(c[j + 1] == c[m]) ++j;
if(m != 1) nxt[m] = j;
int p = nxt[m];
while(p >= k && p > 0) {
if(m - p - p >= 1) {
++ans;
break;
}
p = nxt[p];
}
}
}
cout << ans;
return 0;
}