「字符串算法」第3章 KMP 算法课堂过关

「字符串算法」第3章 KMP 算法课堂过关

关于KMP

upd on 2021/4/1 优化一些细节

声明:本文的字符串下标均从1开始,对于某个字符串a,a.substr(i,j)表示a从第i位开始,长度为j的字串

模板题

传送门

KMP算法的大致原理

个人认为其他博客已经讲得很好,这里简单讲,把重点放在next数组上

先推几篇博客:

首先,我们把模板题中的\(s_1\)串称为文本串,重命名为\(s\)\(s_2\)称为模式串,重命名为\(t\)(本文中不区分s与t的大小写)

\(n\)\(s\)的长度,\(m\)\(t\)的长度(会在代码片中出现)

看图,在第一轮匹配中,匹配到了一个不相等的位置,如果用暴力,那就是从头再匹配,但是可以看到\(t\)串中有一段重复的“ABC”,无需重复匹配,所以第二轮直接跳到如图所示的位置比较两个蓝色的部分

这就是KMP算法的大致思路

next数组

定义

看了KMP的大致原理,相信大家都产生了疑问:我怎么知道要让T串跳到哪个位置呢?这就要用到next数组了,这是KMP的核心,也是难点

先不用管怎么求next数组,看定义(我自己写的):

\(j=next_i\),则有\(j<i\)\(t.substr(1,j)==t.substr(i-j+1,i)\),且对于任意\(k(j<k<i)\),\(t.substr(1,k)≠t.substr(i-k+1,k)\)

也就是说,next[i]表示“T中以i结尾的非前缀字串”与“T的前缀”能匹配的最长长度,当不存在这样的j时,next[i]=0

举个例子:

若T="ABCDABCE",则对应的next={0 0 0 0 1 2 3 0}

应用

根据next数组的定义,next中存储的是长度,但是由于它是T的某个前缀字串的长度,我们也可以将next当做下标使用(一定要弄清楚,不然后面很蒙)

仍然用上面的图片真懒呐

设S的指针为i,T的指针为j,表示当前完成匹配的位置(也就是说S[i]和T[j]是相等的)

第一轮匹配中,当\(j==7\)时,我们发现\(t\)的下一位和\(s\)的下一位不等,但是\(t\)的第57位和13位是一样的,即next[7]=3,所以我们需要将\(t\)的指针(j)跳到第3位,也就是j=next[j],这里有一些细节不是很好理解,KMP在实现时是很巧妙的,我们放到整段代码理解:

		while(j != 0 && s[i] != t[j+1])
			j = next[j];
		if(s[i] == t[j+1])
			j++;
		if(j == m){//j==m标志着已经全部完成匹配
			printf("%d\n",i - m + 1);
			j = next[j];
		}

求法

这里是整个KMP最难理解的部分,所以放到最后

先贴出代码

	next[1] = 0;//初始化
	for(int i = 2 , j = 0 ; i <= m ; i++){
		while(j != 0 && t[j+1] != t[i])
			j=next[j];//全算法最confusing的语句
		if(t[j+1] == t[i])
			j++;
		next[i] = j;
	}

考虑暴力枚举:最外层循环枚举每一位\(i\),第二层枚举next[i],里层判断第二层枚举的是否合法

显然,时间复杂度是在\(O(n^2)~O(n^3)\),还不如\(O(n\cdot m)\)的暴力匹配

优化求法:

先提前声明:求next[i]是要用到next[1~i-1]的,所以我们要从前向后顺序枚举i

定义“候选项”的概念(可能跟《算法竞赛……》的不大一样):如果j满足 t.substr(1,j)==t.substr(i-1-j-1,j)&&j<i-1则j是next[i]的一个候选项

例子:

绿色表示相等的两个字串,则j是next[i]的一个候选项,若标成蓝色的两个字符相等,则候选项j是合法的,next[i]就是所有合法的\(j\)中的最大值+1

很显然,对于next[i]而言,next[i-1]是它的候选项,但是,问题是next[next[i-1]],next[next[next[i-1]]],......都是候选项,为什么呢?还是看图:

假设next[13]=5,根据\(next\)的定义,标绿色部分是相等的,再细化一下绿色部分中相等的部分:假设next[5]=2,同理,第二行(不计最上面的下标行)的黄色部分相等,又因为绿色部分相等,我们可以得到第三行的黄色部分都是相等的,再简化为第4行,会发现:这不是和第一行一样了吗(只是长度小了)!

以此类推,可以得到next[i-1],next[next[i-1]],next[next[next[i-1]]],......都是候选项,且他们的值是从左向右递减的,因此,按照这个顺序找到第一个合法的候选值之后,我们就可以确定next[i]

重新看一下代码:

	next[1] = 0;
	for(int i = 2 , j = 0 ; i <= m ; i++){
		while(j != 0 && t[j+1] != t[i])//找到第一个合法的候选项
			j=next[j];//缩小长度
		if(t[j+1] == t[i])
			j++;
		next[i] = j;
	}

发现,每一轮循环没有j=next[i-1]的语句。原因很简单:上一轮结束时语句next[i]=j决定了这一轮刚开始就有j==next[i-1],注意这里的前后的\(i\)不一样(都不是同一轮循环了)不要学傻了

时间复杂度

上结论:\(O(n+m)\)

\(next\)数组的求值为例:

	next[1] = 0;
	for(int i = 2 , j = 0 ; i <= m ; i++){
		while(j != 0 && t[j+1] != t[i])
			j=next[j];
		if(t[j+1] == t[i])
			j++;
		next[i] = j;
	}

最外层显然是\(O(m)\)的,问题是里面

while循环中,\(j\)是递减的,但是又不会变成负数,所以整个过程中,\(j\)的减小幅度不会超过\(j\)增加的幅度,而\(j\)每次才增加1,最多增加\(m\)次,故\(j\)的总变化次数不超过\(2m\),整个时间复杂度近似认为是\(O(m)\)

如果还不能理解,就想像一个平面直角坐标系,\(x\)轴为\(i\)\(y\)轴为\(j\),从原点出发,\(i\)每向右一个单位,\(j\)最多向上一个单位,\(j\)也可以往下掉(while循环),但不能掉到第四象限,\(j\)向下掉的高度之和就是while内语句执行的总次数,是绝对不会超过\(m\)

匹配的循环与上述相近,时间为\(O(n+m)\),不再赘述

所以,总的时间复杂度为\(O(n+m)\)

模板题代码

不要问模板题输出的最后一行是什么意思,我也不知道,反正输出\(next\)数组就对了

#include <iostream>
#include <cstdio>
#include <cstring>
#define nn 1000010
using namespace std;
int sread(char s[]) {
	int siz = 1;
	do
		s[siz] = getchar();
	while(s[siz] < 'A' || s[siz] > 'Z');
	while(s[siz] >= 'A' && s[siz] <= 'Z') {
		++siz;
		s[siz] = getchar();
	}
	--siz;
	return siz;
}
char s[nn];
char t[nn];
int next[nn];
int n , m;
int main() {
	n = sread(s);
	m = sread(t);
	next[1] = 0;
	for(int i = 2 , j = 0 ; i <= m ; i++){
		while(j != 0 && t[j+1] != t[i])
			j=next[j];
		if(t[j+1] == t[i])
			j++;
		next[i] = j;
	}
	for(int i = 1 , j = 0 ; i <=n ; i++){
		while(j != 0 && s[i] != t[j+1])
			j = next[j];
		if(s[i] == t[j+1])
			j++;
		if(j == m){
			printf("%d\n",i - m + 1);
			j = next[j];
		}
	}
	for(int i = 1 ; i <= m ; i++)
		printf("%d " , next[i]);
	return 0;
}

A. 【例题1】子串查找

题目

代码

#include <iostream>
#include <cstdio>
#define nn 1000010
using namespace std;
int ssiz , tsiz;
int sread(char *s) {
	int siz = 0;
	char c = getchar();
	while(!((c >= 'a' && c <= 'z') || (c >= 'A' && c <= 'Z')))
		c = getchar();
		
	while((c >= 'a' && c <= 'z') || (c >= 'A' && c <= 'Z'))
		s[++siz] = c , c = getchar();
	return siz;
}
char s[nn] , t[nn];
int nxt[nn];
int main() {
	ssiz = sread(s);
	tsiz = sread(t);
	
	nxt[1] = 0;
	for(int i = 2 ; i <= tsiz ; i++) {
		int j = nxt[i - 1];
		while(j != 0 && t[j + 1] != t[i])
			j = nxt[j];
		nxt[i] = j + 1;
	}
	
	int ans = 0;
	for(int i = 1 , j = 0 ; i <= ssiz ; i++) {
		while(j != 0 && s[i] != t[j + 1])
			j = nxt[j];
		if(s[i] == t[j + 1])
			++j;
		if(j == tsiz)
			j = nxt[j] , ++ans;
	}
	cout << ans;
	return 0;
}

B. 【例题2】重复子串

题目

题目有误

输入若干行,每行有一个字符串,字符串仅含英文字母。特别的,字符串可能为.即一个半角句号,此时输入结束。

第五组数据的字符串包含数字字符,有图为证:

思路

设字符串长度为\(siz\)

Hash

关于字符串Hash

不难想也很好写的一种方法

直接枚举最小周期长度\(i\),显然,\(siz\)一定是\(i\)的倍数,所以,这只需要\(O(\sqrt n)\)的时间复杂度

假设我们已经枚举到\(p\)的因数\(x\),就可以直接用\(O(\frac{siz}{x})\)的时间复杂度验证该子字符串是否是周期,代码如下:

inline bool check(int x) {
	ul key = hs[x];
	for(int i = x + 1 ; i + x - 1 <= siz ; i += x)
		if(hs[i + x - 1] - hs[i - 1] * pw[x] != key)//获取字符串s从下标i开始,长度为x的子串的Hash值 , 判断和key是否相等
			return false;
	return true;
}

KMP

下面讲好些不太好想的KMP做法

先上结论:
命名输入进来的字符串为\(S\),预处理得到\(S\)\(nxt\)数组
\(siz\%(siz-nxt_{siz})==0\),则\(siz-nxt_{siz}\)\(S\)的最小周期,也就是说,此时答案为\(siz / (siz - nxt_{siz})\)
否则,答案为"1"

献上图解:

代码

Hash

#include <iostream>
#include <cstdio>
#include <cstring>
#define nn 1000010
#define ul unsigned long long
using namespace std;
#define value(_) (_ >= 'A' && _ <= 'Z' ? (1 + _ - 'A') : (_ >= 'a' && _ <= 'z' ? (27 + _ - 'a') : (_ - '0' + 53) ))
const ul p = 131;

ul hs[nn];
ul pw[nn];
int siz;
char c[nn];

int sread(char *s) {
	int siz = 0;
	char c = getchar();
	if(c == '.')return -1;
	while(!((c >= 'a' && c <= 'z') || (c >= 'A' && c <= 'Z') || (c >= '0' && c <= '9')))
		if((c = getchar()) == '.')	return -1;
		
	while((c >= 'a' && c <= 'z') || (c >= 'A' && c <= 'Z') || (c >= '0' && c <= '9'))
		s[++siz] = c , c = getchar();
	return siz;
}
inline bool check(int x) {
	ul key = hs[x];
	for(int i = x + 1 ; i + x - 1 <= siz ; i += x)
		if(hs[i + x - 1] - hs[i - 1] * pw[x] != key)
			return false;
	return true;
}
int main() {
	pw[0] = 1;
	for(int i = 1 ; i <= nn - 5 ; i++)
		pw[i] = pw[i - 1] * p;
	while((siz = sread(c)) != -1) {
		for(int i = 1 ; i <= siz ; i++)
			hs[i] = hs[i - 1] * p + value(c[i]);
		
		int ans = 0;
		for(int i = 1 ; i * i <= siz ; i++) {
			if(siz % i == 0)
				if(check(i)) {
					ans = i;
					break;
				}
				else {
					if(check(siz / i))
						ans = siz / i;
				}
		}
		printf("%d\n" , siz / ans);
		memset(c, 0  , sizeof(c));
		memset(hs , 0 , sizeof(hs));
	}
	return 0;
}

KMP

#include <iostream>
#include <cstdio>
#include <cstring>
#define nn 1000010
using namespace std;
int siz;
int sread(char *s) {
	int siz = 0;
	char c = getchar();
	if(c == '.')return -1;
	while(!((c >= 'a' && c <= 'z') || (c >= 'A' && c <= 'Z') || (c >= '0' && c <= '9')))
		if((c = getchar()) == '.')	return -1;
		
	while((c >= 'a' && c <= 'z') || (c >= 'A' && c <= 'Z') || (c >= '0' && c <= '9'))
		s[++siz] = c , c = getchar();
	return siz;
}
char s[nn];
int nxt[nn];
int main() {
	while(true) {
		memset(s , 0 , sizeof(s));
		memset(nxt , 0 , sizeof(nxt));
		siz = sread(s);
		if(siz == -1)	break;
		nxt[1] = 0;
		for(int i = 2 , j = 0 ; i <= siz ; i++) {
			while(s[i] != s[j + 1] && j != 0)
				j = nxt[j];
			if(s[i] == s[j + 1])
				++j;
			nxt[i] = j;
		}
		if(siz % (siz - nxt[siz]) == 0)
			printf("%d\n" , siz / (siz - nxt[siz]));
		else
			printf("1\n");
	}
	return 0;
}

C. 【例题3】周期长度和

题目

传送门

思路&代码

以前写过,传送门

题目

传送门

这题意不是一般人能读懂的,为了读懂题目,我还特意去翻了题解[手动笑哭]

题目大意:

给定一个字符串s

对于\(s\)的每一个前缀子串\(s1\),规定一个字符串\(Q\),\(Q\)满足:\(Q\)\(s1\)的前缀子串且\(Q\)不等于\(s1\)\(s1\)是字符串\(Q+Q\)的前缀.设\(siz\)为所有满足条件的\(Q\)\(Q\)的最大长度(注意这里仅仅针对\(s1\)而不是\(s\),即一个\(siz\)的值对应一个\(s1\))

求出所有\(siz\)的和

不要被这句话误导了:

求给定字符串所有前缀的最大周期长度之和

正确断句:求给定字符串 所有/前缀的最大周期长度/之和

我就想了半天:既然是"最大周期长度",那不是唯一的吗?为什么还要求和呢?

思路

其实这题要AC并不难(看通过率就知道)

看图

要满足\(Q\)\(s1\)的前缀,则\(Q\)\(1\)~\(5\)位和\(s1\)的1~5位是一样的,又因为\(s1\)\(Q+Q\)的前缀,所以又要满足\(s1\)的6~8位和\(Q+Q\)的6~8位一样,即\(s1\)的6~8位和Q的1~3位相等,回到\(s1\),标蓝色的两个位置相等.

回顾下KMP中\(next\)数组的定义:next[i]表示对于某个字符串a,"a中长度为next[i]的前缀子串"与"a中以第i为结尾,长度为next[i]的非前缀子串"相等,且next[i]取最大值

是不是悟到了什么,是不是感觉这题和\(next\)数组冥冥之中有某种相似之处?

但是,这仅仅只是开始

按照题目的意思,我们要让\(Q\)的长度最大,也就是图中蓝色部分长度最小,但是\(next\)中存的是蓝色部分的最大值,显然,两者相违背,难道我们要改造\(next\)数组吗?明显不行,若\(next\)存储的改为最小值,则原来求\(next\)的方法行不通.考虑换一种思路(一定要对KMP中\(next\)的求法理解透彻,不然下面看不懂,不行的复习一下),我们知道对于next[i],next[next[i-1]],next[next[next[i]]]...都能满足"前缀等于以\(i\)结尾的子串"这个条件,且越往后,值越小,所以,我们的目标就定在上面序列中从后往前第一个不为0的\(next\)

极端条件下,暴力跑可以去到\(O(n^2)\),理论上会超时(我没试过)

两种优化:

  1. 记忆化,时间效率应该是O(n)这里不详细讲,可以去到洛谷题解查看
  2. 倍增(我第一时间想到并AC的做法):
    我们将j=next[j]这一语句称作"j跳了一次"(感觉怪怪的),将next拓展为2维,next[i][k]表示结尾为i,j跳了2^k的前缀字符长度(也就是next[i][0]等价于原来的next[i])
    借助倍增LCA的思想(没学没关系,现学现用),这里不做赘述,上代码
		int tmp = i;
		for(rr int j = siz[i] ; j >= 0 ; --j)//siz[i]是next[i][j]中第一个为0的小标j,注意倒序枚举
			if(next[tmp][j] != 0)//如果不为0则跳
				tmp = next[tmp][j];

倍增方法在字符串长度去到\(10^6\)时是非常危险的,带个\(\log\)理论是\(2\cdot 10^7\)左右,常数再大那么一丢丢就TLE了,还好数据比较水,但是作为倍增和KMP的练习做一下也是不错的

最后,记得开longlong(不然我就一次AC了)

完整代码

#include <iostream>
#include <cmath>
#include <cstdio>
#define nn 1000010
#define rr register
#define ll long long
using namespace std;
int next[nn][30] ;
int siz[nn];
char s[nn];
int n;
int main() {
//	freopen("P3435_3.in" , "r" , stdin);
	cin >> n;
	do
		s[1] = getchar();
	while(s[1] < 'a' || s[1] > 'z');
	for(rr int i = 2 ; i <= n ; i++)
		s[i] = getchar();
	
	next[1][0] = 0;
	for(rr int i = 2 , j = 0 ; i <= n ; i++) {
		while(j != 0 && s[i] != s[j + 1])
			j = next[j][0];
		if(s[j + 1] == s[i])
			++j;
		next[i][0] = j;
	}
	
	rr int k = log(n) / log(2) + 1;
	for(rr int j = 1 ; j <= k ; j++)
		for(rr int i = 1 ; i <= n ; i++) {
			next[i][j] = next[next[i][j - 1]][j - 1];
			if(next[i][j] == 0)
				siz[i] = j;
		}
	ll ans = 0;
	for(rr int i = 1 ; i <= n ; ++i) {
		int tmp = i;
		for(rr int j = siz[i] ; j >= 0 ; --j)
			if(next[tmp][j] != 0)
				tmp = next[tmp][j];
		if(2 * (i - tmp) >= i && tmp != i)
			ans += (ll)i - tmp;
	}
	cout << ans;
	return 0;
} 

D. 【例题4】子串拆分

题目

思路

说明,以下思路时间大致复杂度为\(O(n^2 )\),最坏可以去到\(O(n^3)\),但数据较水可以通过,看了书,上面的解法也是\(O(n^2)\),对于\(1\leq |S|\leq 1.5×10^4\)来说已经是很极限了

其实思路很简单,我们直接枚举子串的左右边界\(L,R\),在右边界扩张的同时把新加入的字符的\(nxt\)求出来.至此,我们得到了子串\(c\),和\(c\)\(nxt\)数组,时间复杂度为\(O(n^2)\)
那么我们如何判断\(c\)是否符合\(c=A+B+C(k\le len(A),1\le len(B) )\)呢?看代码(其实做了B,C题这里很好理解)

			int p = nxt[m];//m为c数组的长度,p即是可能的A的长度
			while(p >= k && p > 0) {
				if(m - p - p >= 1) {
					++ans;
					break;//直接退出,优化
				}
				p = nxt[p];
			}

这个判断的复杂度是可以达到\(O(n)\)的,在数据范围下十分危险

下面看下书中是怎么说的:

我还以为书里有严格\(O(n^2)\)的做法

下面\(p_i\)\(nxt_i\)意义相同

考虑没枚举左端点,假设左端点为\(l\),\(A=S[l,|S|]\),那么对字符串\(A\)跑一次KMP,在匹配的过程中,设匹配到第\(i\)个位置,那么我们就要考虑当前得出的\(j\),显然\(A[1,j]=A[i-j+1,i]\).如果\(i\le 2\cdot j\),那么令\(j=p_j\),此时\(A[i,j]=A[i-j+1,i]\),\(j\)沿指针\(p\)不断回跳,直到\(2\cdot j<i\).然后判断\(j\)是否大于\(k\),如果是,那么累加答案.

因为每次KMP的复杂度是\(O(n)\),所以总时间复杂度为\(O(n^2)\)

核心代码

//每次KMP匹配 
inline void solve(char *a) {
	p[1] = 0;
	int n = strlen(a + 1);
	for(int i = 1 , j = 0 ; i < n ; i++) {
		while(j && a[j + 1] != a[i + 1])
			j = p[j];
		if(a[j + 1] == a[i + 1])
			++j;
		p[i + 1] = j;
	}
	for(int i = 1 , j = 0 ; i < n ; i++) {
		while(j && a[j + 1] != a[i + 1])
			j = p[j];
		if(a[j + 1] == a[i + 1])
			j++;
		while(j * 2 >= i + 1)
			j = p[j];
		if(j >= k)
			++ans;
	}
}
//枚举左端点 
int len = strlen(str + 1) - (k << 1);
for(int i = 0 ; i < len ; i++)
	solve(str + i);

代码

#include <iostream>
#include <cstdio>
#include <cstring>
#define nn 15000
using namespace std;
int sread(char *s) {
	int siz = 0;
	char c = getchar();
	if(c == '.')return -1;
	while(!((c >= 'a' && c <= 'z') || (c >= 'A' && c <= 'Z') || (c >= '0' && c <= '9')))
		if((c = getchar()) == '.')	return -1;
		
	while((c >= 'a' && c <= 'z') || (c >= 'A' && c <= 'Z') || (c >= '0' && c <= '9'))
		s[++siz] = c , c = getchar();
	return siz;
}
int n , k , m;
int ans;
int nxt[nn];
char c[nn] , s[nn];

int main() {
	n = sread(s);
	cin >> k;
	for(int L = 1 ; L <= n ; L++) {
		memset(nxt , 0 , sizeof(nxt));
		memset(c , 0 , sizeof(c));
		for(int i = L ; i <= L + k + k ; i++)
			c[i - L + 1] = s[i];
		m = k + k;
		
		nxt[1] = 0;
		for(int i = 2 ; i <= m ; i++) {
			int j = nxt[i - 1];
			while(c[j + 1] != c[i] && j != 0)
				j = nxt[j];
			if(c[j + 1] == c[i])	++j;
			nxt[i] = j;
		}
		
		for(int R = L + k + k; R <= n ; R++) {
			m = R - L + 1;
			c[m] = s[R];
			
			int j = nxt[m - 1];
			while(c[j + 1] != c[m] && j != 0)
				j = nxt[j];
			if(c[j + 1] == c[m])	++j;
			if(m != 1)	nxt[m] = j;
					
			int p = nxt[m];
			while(p >= k && p > 0) {
				if(m - p - p >= 1) {
					++ans;
					break;
				}
				p = nxt[p];
			}
			
		}
	}
	cout << ans;
	return 0;
}
posted @ 2021-04-02 22:01  追梦人1024  阅读(112)  评论(0编辑  收藏  举报