扩展KMP算法
一 问题定义
给定母串S和子串T,定义n为母串S的长度,m为子串T的长度,suffix[i]为第i个字符开始的母串S的后缀子串,extend[i]为suffix[i]与字串T的最长公共前缀长度。求出所有的extend[1..n]。容易发现,如果存在某个i,使得extend[i] = m,这便是经典的KMP算法要解决的问题。
二 扩展KMP算法思想
和KMP算法的是想类似,充分利用已经比较字符性质来减少冗余的字符比较次数。KMP的思想是充分的利用模式串中所有前缀字串(以模式串为开头的字串)的真前缀和真后缀(指子串的开始字符与子串的最后字符相等的个数)来减少不必要的字符比较,真前缀和真后缀相等的个数保存在next数组中。扩展KMP算法则是利用子串T中的所有后缀子串suffix[i]与字串T的最长公共前缀来减少字符的比较次数,这个最长公共前缀的个数也记录在next数组中。
假设我们已经求出了extend[0..k]并且记录了a和p,p表示在S串中从第i(i=0..k)个字符开始匹配的过程中达到的最远的位置,a表示取p这个最大值所对应的i值。现在求extend[k+1],可以分下面三种情况:
- k+1==p(p肯定是大于或等于k+1的)
直接S[k+1]与T[0]开始比较,并更新a和p的值 - k + 1 + next[k - a + 1] <= p
extend[k+1] = next[k - a + 1] - k + 1 + next[k - a + 1] > p
直接S[p]与T[p - k -1]开始比较,extend[k+1]也对应的从p - k - 1开始往后加,并更新a和p的值
对于第二种和第三种情况怎么得到的呢?这就要靠我们保存的a和p了。如果p > k + 1,那么必然有S[k+1,p-1] = T[k-a+1,p-a],这个可以通过S[a,p-1] = T[0,p-a]推出。我们在next[k-a+1]中保存了T的suffix[k-a+1]子串与T的最长公共前缀的长度,假设L=next[k-a+1],那么T[k-a+1,k-a+L]=T[0,L-1]。如果k+1+L<= p,通过S[k+1,p-1] = T[k-a+1,p-a],可以得到S[k+1,k+L]=T[k-a+1,k-a+L]=T[0,L-1],并且S[k+L+1]!=T[0,L]的(因为如果相等的话,T[k-a+1,k-a+L+1]=T[0,L],这与前面的T[k-a+1,k-a+L]=T[0,L-1]是矛盾的),所以extend[k+1]=L。如果k+1+L> p,那么S[k+1,p-1] = T[k-a+1,p-a]=T[0,p-k-2],所以extend[k+1]至少等于p-k-1,然后我们应该继续从p开始比较,因为从p开始都是没有比较过的。
求next的思路和求extend的思路是一样的,只不过是T对自己求extend。
三 扩展KMP算法实现
#define MAXSIZE 400005 int extend[MAXSIZE], next[MAXSIZE]; void get_next(char *t, int n) { int a, p, k, i; next[0] = 0; a = 0; p = 1; for (k = 1; k < n; k++) { if (k == p) { i = k; while (t[i] == t[i - k]) i++; next[k] = i - k; if (i == p) p++; else p = i; a = k; } else if (k + next[k - a] <= p) next[k] = next[k - a]; else { next[k] = p - k; i = p; while (t[i] == t[i - k]) { i++; next[k]++; } if (i > p) { p = i; a = k; } } } } void get_extend(char *s, int sn, char *t, int tn) { int a, p, k, i; p = a = 0; for (k = 0; k < sn; k++) { if (k == p) { i = k; while (i - k < tn && s[i] == t[i - k]) i++; extend[k] = i - k; if (i == p) p++; else p = i; a = k; } else if (k + next[k - a] <= p) extend[k] = next[k - a]; else { extend[k] = p - k; i = p; while (i - k < tn && s[i] == t[i - k]) { i++; extend[k]++; } if (i > p) { p = i; a = k; } } } }