「笔记」KMP 算法
写在前面
不是我吹,我是真的刚学会 KMP 啊(
引入
给定字符串 \(s_1,s_2\left(|s_2|\le |s_1|\right)\),求 \(s_2\) 在 \(s_1\) 中的所有出现位置。
\(1\le |s_2|\le |s_1|\le 5\times 10^6\)。
1S,128MB。
朴素的想法是枚举 \(s_2\) 在 \(s_1\) 中的开头位置,暴力枚举判断是否匹配。如果失配,则抛弃当前已匹配的部分,到下一位置再从开头匹配。时间复杂度为 \(O(|s_1||s_2|)\)。
而 KMP 算法可以在 \(O(|s_1| + |s_2|)\) 的时空复杂度内解决上述问题,且常数较小。
定义
\(s[i:j]\):字符串 \(s\) 的子串 \(s_i\cdots s_j\)。
真前/后缀:字符串 \(s\) 的真前缀定义为满足不等于它本身的 \(s\) 的前缀。同理就有了真后缀的定义:满足不等于它本身的 \(s\) 的后缀。
\(\operatorname{border}\):字符串 \(s\) 的 \(\operatorname{border}\) 定义为,满足既是 \(s\) 的真前缀,又是 \(s\) 的真后缀的最长的字符串 \(t\)。
如 \(\texttt{aabaa}\) 的 \(\operatorname{border}\) 为 \(\texttt{aa}\)。
\(\operatorname{fail}\):字符串 \(s\) 的 \(\operatorname{fail}\) 是一个长度为 \(|s|\) 的整数数组,它又被称为 \(s\) 的失配指针。\(\operatorname{fail}_i\) 表示前缀 \(s[1:i]\) 的 \(\operatorname{border}\) 的长度,即:
特别的,若不存在这样的 \(j\),则 \(\operatorname{fail}_i = 0\)。如 \(\texttt{aabaa}\) 的 \(\operatorname{fail} = \{0, 1, 0 , 1, 2\}\)。
原理
在朴素算法中,如果在某一位上失配,则会抛弃当前已匹配的部分,跳到下一个位置再从开头进行匹配。
而 KMP 利用了当前已匹配的部分,使得在下一个位置时不必从开头进行匹配,从而对朴素算法进行了加速。
举个例子,如下图所示:
失配指针
\(s\) 的失配指针 \(\operatorname{fail}\) 可以通过在 \(s\) 上按上述思想匹配自身求得。下述算法中枚举到第 \(i\) 位时即可求得 \(\operatorname{fail}_i\)。
首先显然有 \(\operatorname{fail}_1 = 0\)。设枚举到第 \(i\) 位,考虑已知 \(\operatorname{fail}_1\sim \operatorname{fail}_{i-1}\) 的情况下如何求得 \(\operatorname{fail}_i\)。
设当前匹配部分为 \(s[i-l,i-1]\),即有 \(s[i-l,i-1] = s[1,l]\)。则显然有 \(\operatorname{fail}_{i-1} = l\)。接下来考察 \(s_i = s_{l+1}\) 是否成立。
若成立,则有 \(s[i-l,i] = s[1,l + 1]\),得 \(\operatorname{fail}_i = l + 1\)。
若不成立,一种朴素的想法是减小已匹配长度 \(l\) 并暴力检查,直到找到最大的一个 \(l'<l\),满足 \(s[i-l',i-1] = s[1,l']\) 且 \(s_{i}=s_{l'+1}\),此时 \(\operatorname{fail}_i = l'+1\)。考虑利用已匹配部分的 border 加速上述过程。
引理:满足 \(l'<l\) 且 \(s[i-l',i-1] = s[1,l']\) 的 \(l'\) 的最大的 \(l'\) 是 \(\operatorname{fail}_{l}\)。
证明:考虑反证法,设存在 \(j\) 满足 \(\operatorname{fail}_{l}<j<l\) 是最大的满足条件的 \(l'\)。
根据条件,有 \(s[i-j,i-1] = s[1,j]\),又 \(j<l\),则 \(s[i-j,i-1]\) 是 \(s[i-l,i-1]\) 的一段后缀, \(s[1,j]\) 是 \(s[1,l]\) 的一段前缀。则有 \(s[1,j] = s[l - j, l]\) 成立。
又 \(j > \operatorname{fail}_{l}\),根据 border 的定义,则 \(\operatorname{fail}_{l}\) 应为 \(j\),这与已知矛盾,反证原结论成立。直观的理解如下所示:
\[\large \overbrace{\underbrace{s_1 ~ s_2}_{\operatorname{fail}_{l}} ~ s_3 ~ s_4}^{l =\operatorname{fail}_{i-1}} ~ \cdots ~ \overbrace{s_{i-3} ~ s_{i-2} ~ \underbrace{s_{i-1} ~ s_{i}}_{\operatorname{fail}_{l}}}^{l = \operatorname{fail}_{i-1}} ~ s_{i+1} \]
若 \(l'=\operatorname{fail}_{l}\) 仍不满足 \(s_{i}=s_{l'+1}\),则一直令 \(l' = \operatorname{fail}_{l'}\),直到满足条件或 \(l' = 0\)。
模拟上述过程,可以得到下述代码:
fail[1] = 0;
for (int i = 2, j = 0; i <= n2; ++ i) { //j 为匹配长度
while (j > 0 && s2[i] != s2[j + 1]) j = fail[j]; //找到满足条件的 border
if (s2[i] == s2[j + 1]) ++ j; //匹配成功
fail[i] = j;
}
匹配
按照上述过程实现即可,代码如下:
for (int i = 1, j = 0; i <= n1; ++ i) { //j 为匹配长度
while (j > 0 && (j == n2 || s1[i] != s2[j + 1])) j = fail[j]; //找到满足条件的 border,注意当整个串匹配成功的特判。
if (s1[i] == s2[j + 1]) ++ j; //第 j 位匹配成功
if (j == n2) printf("%d\n", i - n2 + 1); //整个串匹配成功
}
完整代码
//知识点:KMP
/*
By:Luckyblock
*/
#include <algorithm>
#include <cctype>
#include <cstdio>
#include <cstring>
#define LL long long
const int kN = 1e6 + 10;
//=============================================================
char s1[kN], s2[kN];
int n1, n2;
int fail[kN];
//=============================================================
inline int read() {
int f = 1, w = 0;
char ch = getchar();
for (; !isdigit(ch); ch = getchar())
if (ch == '-') f = -1;
for (; isdigit(ch); ch = getchar()) w = (w << 3) + (w << 1) + (ch ^ '0');
return f * w;
}
void Chkmax(int &fir, int sec) {
if (sec > fir) fir = sec;
}
void Chkmin(int &fir, int sec) {
if (sec < fir) fir = sec;
}
//=============================================================
int main() {
scanf("%s", s1 + 1);
scanf("%s", s2 + 1);
n1 = strlen(s1 + 1), n2 = strlen(s2 + 1);
fail[1] = 0;
for (int i = 2, j = 0; i <= n2; ++ i) {
while (j > 0 && s2[i] != s2[j + 1]) j = fail[j];
if (s2[i] == s2[j + 1]) ++ j;
fail[i] = j;
}
for (int i = 1, j = 0; i <= n1; ++ i) {
while (j > 0 && (j == n2 || s1[i] != s2[j + 1])) j = fail[j];
if (s1[i] == s2[j + 1]) ++ j;
if (j == n2) printf("%d\n", i - n2 + 1);
}
for (int i = 1; i <= n2; ++ i) printf("%d ", fail[i]);
return 0;
}
复杂度
求失配指针与匹配两部分的代码类似,仅解释其中一部分。
for (int i = 2, j = 0; i <= n2; ++ i) {
while (j > 0 && s2[i] != s2[j + 1]) j = fail[j];
if (s2[i] == s2[j + 1]) ++ j;
fail[i] = j;
}
代码中仅有 while
的执行次数是不明确的。但可以发现,在 while
中 \(j\) 每次至少减少 1,每层循环中 \(j\) 每次至多增加 1。
又时刻保证 \(j\ge 0\),则 \(j\) 的减少量不大于 \(j\) 的增加量,即 \(n_2\)。故 while
最多执行 \(n_2\) 次,则整个循环的复杂度为 \(O(n)\) 级别。
例题
CF126B Password
给定一字符串 \(s\),求一个字符串 \(t\),满足 \(t\) 既是 \(s\) 的前缀,又是 \(s\) 的后缀,同时 \(t\) 还在 \(s\) 中间出现过(即不作为 \(s\) 的前后缀出现)。
\(1\le |s|\le 10^6\)。
2S,256MB。
既是 \(s\) 的前缀,又是 \(s\) 的后缀的串可以通过枚举 \(\operatorname{fail}_n\),\(\operatorname{fail}_{\operatorname{fail}_n},\cdots\) 获得。
在 \(s\) 中间出现过的所有 \(s\) 的前缀为 \(s[1:\operatorname{fail}_2]\sim s[1:\operatorname{fail}_{n-1}]\),用桶判断这两部分有无重复元素即可。
代码:A submission。
P4391 [BOI2009]Radio Transmission 无线传输
给定一字符串 \(s_1\),已知它是由某个字符串 \(s_2\) 不断自我连接形成的,即有:
\[s_1 = s_2 + s_2 + \cdots+s_2[1,|s_1|\bmod |s_2|] \]求字符串 \(s_2\) 的最短长度。
\(1\le |s_1|\le 10^6\)。
1S,128MB。
考虑一个更简单的问题,如何判断 \(s_1\) 的一个前缀 \(i\) 是否为 \(s_1\) 的循环节?
考虑求 \(s_1\) 的 \(\operatorname{fail}\),显然当 \(i\mid |s_1|\) 且 \(\operatorname{fail}_{|s_1|} = |s_1| - i\) 时 \(i\) 为循环节。
正确性显然,若该条件成立,则保证了 \(s[1:i] = s[i+1:2i], s[i+1:2i] = s[2i+1:3i],\cdots\) 如下所示:
发现呈现错位相等的关系,对应的,则有 \(s[1:i] = s[|s_1| - i+1, |s_1|]\),可得 \(i\) 是一个循环节。
由上,可以得到两种做法。
第一种是暴力枚举前缀 \(i\),判断 \(\operatorname{fail}_{n - (n\bmod i)}\) 是否等于 \(n - (n\bmod i) - i\),且 \(\operatorname{fail}_n\ge i\)。
第一个条件保证了 \(i\) 是 \(s_1[1:n - (n\bmod i)]\) 部分的循环节,第二个条件保证了剩下的部分是 \(i\) 的一个前缀。
第二种是直接输出 \(n - \operatorname{fail}_n\)。原理如下所示:
显然可知最后的不完整部分是 \(n - \operatorname{fail}_n\) 的一个前缀。又保证了 \(\operatorname{fail}_n\) 是最长的既是 \(s_1\) 的前缀又是 \(s_1\) 的后缀的字符串,则 \(n-\operatorname{fail}_n\) 即为答案。
总复杂度均为 \(O(|s_1|)\) 级别。
//知识点:KMP
/*
By:Luckyblock
*/
#include <algorithm>
#include <cctype>
#include <cstdio>
#include <cstring>
#define LL long long
const int kN = 1e6 + 10;
//=============================================================
char s[kN];
int n, fail[kN];
//=============================================================
inline int read() {
int f = 1, w = 0;
char ch = getchar();
for (; !isdigit(ch); ch = getchar())
if (ch == '-') f = -1;
for (; isdigit(ch); ch = getchar()) w = (w << 3) + (w << 1) + (ch ^ '0');
return f * w;
}
//=============================================================
int main() {
n = read();
scanf("%s", s + 1);
fail[1] = 0;
for (int i = 2, j = 0; i <= n; ++ i) {
while (j && s[i] != s[j + 1]) j = fail[j];
if (s[i] == s[j + 1]) ++ j;
fail[i] = j;
}
//Sol 2:
printf("%d\n", n - fail[n]);
return 0;
//Sol 1:
for (int i = 1; i <= n; ++ i) {
int lth = n - (n % i);
if (fail[lth] == lth - i && fail[n] >= n % i) {
printf("%d\n", i);
return 0;
}
}
return 0;
}
「NOI2014」动物园
\(n\) 组数据,每次给定一字符串 \(s\)。
定义 \(\operatorname{num}_i\) 表示是 \(s[1:i]\) 的前后缀,且长度不大于 \(\left\lfloor\frac{i}{2}\right\rfloor\) 的字符串的个数。
求:\[\prod_{i=1}^{n}\left(\operatorname{num}_i + 1\right)\pmod {10^9 + 7} \]\(1\le n\le 5\),\(1\le |s|\le 10^6\)。
1S,512MB。
做法是自己 YY 的,效率被爆踩但是能过(
记 \(\mathbf{B}(i)\) 表示满足既是前缀 \(s[1:i]\) 的真前缀,又是其真后缀的字符串组成的集合。
先不考虑长度不大于 \(\left\lfloor\frac{i}{2}\right\rfloor\) 这一限制,对于前缀 \(s[1:i]\),显然 \(\operatorname{num}_i\) 的值为 \(|\mathbf{B}(i)|\)。则显然有 \(\operatorname{num}_{i} = \operatorname{num}_{\operatorname{fail}_i} + 1\),表示在 \(\operatorname{num}_i\) 的基础上计入 \(s[1:i]\) 的 \(\operatorname{border}\) 的贡献。\(\operatorname{num}\) 可在 KMP 算法中顺便求得。
再考虑限制,若前缀 \(s[1:i]\) 的 \(\operatorname{border}\) 的长度大于 \(\left\lfloor\frac{i}{2}\right\rfloor\),则需要不断跳 \(\operatorname{fail}\),跳到第一个满足长度合法的位置 \(j\in \mathbf{B}(i)\),再统计其贡献 \(\operatorname{num}_j\)。
暴跳实现可以获得 50pts 的好成绩。
发现跳 \(\operatorname{fail}\) 过程中对应的字符串长度会缩短(废话),考虑倒序枚举各位置 \(i\),使得 \(\left\lfloor\frac{i}{2}\right\rfloor\) 也呈现递减的状态。
考虑暴跳过程,显然是由于某些 \(\operatorname{fail}\) 的转移被重复统计,导致暴跳效率较低。考虑并查集的思路,将重复的转移进行路径压缩。
设 \(\operatorname{pos}_{i}\) 表示前缀 \(s[1:i]\) 在跳 \(\operatorname{fail}\) 之后对应的最大的第一个满足长度合法的 \(\mathbf{B}\) 中的元素,初始值为 \(\operatorname{pos}_i = i\)。在暴力跳 \(\operatorname{fail}\) 时,更新沿途遍历到的 \(\operatorname{pos}\) 即可。
这个路径压缩的复杂度我并不会证,但是感觉跑的还蛮快的= =
//知识点:KMP
/*
By:Luckyblock
*/
#include <algorithm>
#include <cctype>
#include <cstdio>
#include <cstring>
#define LL long long
const int kN = 1e6 + 10;
const int mod = 1e9 + 7;
//=============================================================
int n, ans, next[kN], num[kN], pos[kN];
char s[kN];
//=============================================================
inline int read() {
int f = 1, w = 0;
char ch = getchar();
for (; !isdigit(ch); ch = getchar())
if (ch == '-') f = -1;
for (; isdigit(ch); ch = getchar()) w = (w << 3) + (w << 1) + (ch ^ '0');
return f * w;
}
void Init() {
ans = 1;
scanf("%s", s + 1);
n = strlen(s + 1);
}
void KMP() {
for (int i = 2, j = 0; i <= n; ++ i) {
while (j > 0 && s[i] != s[j + 1]) j = next[j];
if (s[i] == s[j + 1]) ++ j;
next[i] = j;
if (! j) continue ;
pos[j] = j; //初始化
num[j] = 1ll * (num[next[j]] + 1ll) % mod;
}
}
int Find(int x_, int lth_) {
if (pos[x_] <= lth_ / 2) return pos[x_];
return pos[x_] = Find(next[pos[x_]], lth_); //路径压缩
}
//=============================================================
int main() {
int t = read();
while (t --) {
Init(); KMP();
for (int i = n; i >= 2; -- i) {
pos[next[i]] = Find(next[i], i); //找到贡献位置
if (! pos[next[i]]) continue ; //特判无贡献情况
ans = 1ll * ans * (num[pos[next[i]]] + 1) % mod;
}
printf("%d\n", ans);
}
return 0;
}
写在最后
参考资料: