hdu6153后缀数组或扩展KMP
前两天刷了几题leetcode,感觉挺简单,于是又想刷刷hduoj了。随便打开没做过的一页,找了一题通过人数最多的,就是这道6153.
①.看完题没想太多,觉得应该是后缀数组(多年没刷题的我字符串这一块对后缀数组记忆最深吧),因为S1和S2长度都一百万,n^2受不了。nlogn应该行。
②.用后缀数组的话,需要会用后缀数组求子串出现次数。如果是任意子串,还不太好办,但是这里的子串只会是后缀,那就好办了——只需看此后缀在所有后缀中的前缀的数量,也即与本后缀的LCP等于len(s)的后缀数量。因为后缀数组是排过序的,只要往后看,直到height值小于len(s)为止。
③.会了②,就好办了。把S1和S2连起来(中间插入一个比如~号,记为S),求出S的后缀数组SA;又对S2求出后缀数组SA2。
对于每一个S2的后缀,用在SA中求出的出现次数减去SA2中求出的出现次数,就是此后缀在S1中的出现次数(因为它不会在跨S1和S2时出现)。
④.复杂度分析。构造后缀数组最优算法是O(n),用倍增算法构造后缀数组是O(nlogn)。后面的计算步骤,因为S2长度n为一百万,要计算n次,每次用②中的办法往后找,最坏情况要O(n)。但是不意味着这里的总复杂度只能O(n^2)。因为后缀数组是排好序的,按后缀数组的逆序处理这n次查找,那么第k次查找可以充分利用第k-1次查找的结果,往后滑动。也就是说这n次查找基本不会重叠,所以最后查找计算部分的总复杂度仍为O(n)。
这个地方我举个例子,以字符串aabaaaab~aabaaaab为例,它有17个后缀,在后缀数组中的顺序为:
aaaab
aaaab~aabaaaab
aaab
aaab~aabaaaab
aab
aabaaaab
aabaaaab~aabaaaab
aab~aabaaaab
ab
abaaaab
abaaaab~aabaaaab
ab~aabaaaab
b
baaaab
baaaab~aabaaaab
b~aabaaaab
~aabaaaab
求出后缀baaaab的出现次数为2以后,再求后缀b的出现次数时,因为知道b是baaaab的子串(height值等于b的长度),可以直接滑到b~aabaaaab进行判断。
那最后整个算法的复杂度还是卡在构造后缀数组部分,如果用倍增,那整个算法的复杂度为O(nlogn)
⑤.具体实现上,因为我要复用两次后缀数组的代码,所以每次预处理出data数组,再处理出d数组。
d数组就是逆序存的每个后缀在整个串中出现的次数。
代码如下:
1 /* 2 * Author : ben 3 */ 4 #include <cstdio> 5 #include <cstdlib> 6 #include <cstring> 7 typedef long long LL; 8 const int MAXN = 2010000; 9 char s[MAXN]; 10 int sa[MAXN], height[MAXN], rank[MAXN], N; 11 int tmp[MAXN], top[MAXN]; 12 void makesa() { 13 int i, j, len, na; 14 na = (N < 256 ? 256 : N); 15 memset(top, 0, na * sizeof(int)); 16 for (i = 0; i < N; i++) { 17 top[rank[i] = s[i] & 0xff]++; 18 } 19 for (i = 1; i < na; i++) { 20 top[i] += top[i - 1]; 21 } 22 for (i = 0; i < N; i++) { 23 sa[--top[rank[i]]] = i; 24 } 25 for (len = 1; len < N; len <<= 1) { 26 for (i = 0; i < N; i++) { 27 j = sa[i] - len; 28 if (j < 0) { 29 j += N; 30 } 31 tmp[top[rank[j]]++] = j; 32 } 33 sa[tmp[top[0] = 0]] = j = 0; 34 for (i = 1; i < N; i++) { 35 if (rank[tmp[i]] != rank[tmp[i - 1]] 36 || rank[tmp[i] + len] != rank[tmp[i - 1] + len]) { 37 top[++j] = i; 38 } 39 sa[tmp[i]] = j; 40 } 41 memcpy(rank, sa, N * sizeof(int)); 42 memcpy(sa, tmp, N * sizeof(int)); 43 if (j >= N - 1) { 44 break; 45 } 46 } 47 } 48 49 void lcp() { 50 int i, j, k; 51 for (j = rank[height[i = k = 0] = 0]; i < N - 1; i++, k++) { 52 while (k >= 0 && s[i] != s[sa[j - 1] + k]) { 53 height[j] = (k--), j = rank[sa[j] + 1]; 54 } 55 } 56 } 57 58 char S1[MAXN], S2[MAXN]; 59 int data1[MAXN], data2[MAXN]; 60 int d[MAXN]; 61 62 void makedata(int *data) { 63 data[0] = 1; 64 for (int i = N - 2; i > 0; i--) { 65 int leni = N - 1 - sa[i]; 66 int j = N - i - 1; 67 if (height[i + 1] < leni) { 68 data[j] = 1; 69 } else { 70 int k = i + data[j - 1] + 1; 71 while (k < N && height[k] >= leni) { 72 k++; 73 } 74 data[j] = k - i; 75 } 76 // cout << data[j] << endl; 77 } 78 } 79 80 const LL MOD_NUM = 1000000007LL; 81 int work(int lens1, int lens2) { 82 int ans = 0; 83 strcpy(s, S2); 84 N = lens2 + 1; 85 makesa(); 86 lcp(); 87 makedata(data1); 88 for (int i = 0; i < lens2; i++) { 89 int j = lens2 - i - 1; 90 d[i] = data1[lens2 - rank[j]]; 91 // cout << d[i] << endl; 92 } 93 strcpy(s, S1); 94 s[lens1] = '#'; 95 s[lens1 + 1] = 0; 96 strcat(s, S2); 97 N = lens1 + lens2 + 2; 98 makesa(); 99 lcp(); 100 makedata(data2); 101 for (int i = 0; i < lens2; i++) { 102 int j = N - i - 2; 103 d[i] = data2[N - 1 - rank[j]] - d[i]; 104 ans = ans % MOD_NUM + ((i + 1LL) * d[i]) % MOD_NUM; 105 // cout << d[i] << endl; 106 } 107 return ans % MOD_NUM; 108 } 109 110 int main() { 111 int T; 112 scanf("%d", &T); 113 while (T--) { 114 scanf("%s%s", S1, S2); 115 int lens1 = strlen(S1); 116 int lens2 = strlen(S2); 117 printf("%d\n", work(lens1, lens2)); 118 } 119 return 0; 120 }
然而,提交上去,超时了。看来数据量很强。是卡在后缀数组构造的倍增算法上了。但是我手头没有DC3等更优算法的模板。
转念一想,这题通过的人这么多,不可能需要高阶后缀数组算法的。于是回忆了一下还有别的什么算法。想到了扩展KMP。
用扩展KMP解此题的思路主要是要把S1和S2逆序,逆序以后,题目要求的S2的每个后缀就变成前缀了。根据extend数组的定义我们知道,如果extend[i] = x,则表示S2的前x个字符与S1从i开始的x个字符相同,统计extend数组中有多少个x就知道S2的这个前缀在S1中出现的次数。这种统计是可以在线性时间完成的,而前面生成extend数组的时间也为线性,故最后整体复杂度也为线性O(N)。代码如下
1 /* 2 * Author : ben 3 */ 4 #include <cstdio> 5 #include <cstdlib> 6 #include <cstring> 7 #include <cmath> 8 #include <ctime> 9 #include <algorithm> 10 typedef long long LL; 11 const int MAXN = 1001000; 12 char S1[MAXN], S2[MAXN]; 13 int d[MAXN]; 14 const LL MOD_NUM = 1000000007LL; 15 int next[MAXN], extend[MAXN]; 16 void get_next(const char *str, int len){ 17 // 计算next[0]和next[1] 18 next[0] = len; 19 int i = 0; 20 while(str[i] == str[i + 1] && i + 1 < len) { 21 i++; 22 } 23 next[1] = i; 24 int po = 1; //初始化po的位置 25 for(i = 2; i < len; i++) { 26 if(next[i - po] + i < next[po] + po) { //第一种情况,可以直接得到next[i]的值 27 next[i] = next[i - po]; 28 } else { //第二种情况,要继续匹配才能得到next[i]的值 29 int j = next[po] + po - i; 30 if(j < 0) { 31 j = 0; //如果i > po + next[po],则要从头开始匹配 32 } 33 while(i + j < len && str[j] == str[j + i]) { //计算next[i] 34 j++; 35 } 36 next[i] = j; 37 po = i; //更新po的位置 38 } 39 } 40 } 41 void extend_KMP(const char *str, int lens, const char *pattern, int lenp) { 42 get_next(pattern, lenp); // 先计算模式串的next数组 43 // 计算extend[0] 44 int i = 0; 45 while(str[i] == pattern[i] && i < lenp && i < lens) { 46 i++; 47 } 48 extend[0] = i; 49 int po = 0; // 初始化po的位置 50 for(i = 1; i < lens; i++) { 51 if(next[i - po] + i < extend[po] + po) { //第一种情况,直接可以得到extend[i]的值 52 extend[i] = next[i - po]; 53 } else { // 第二种情况,要继续匹配才能得到extend[i]的值 54 int j = extend[po] + po - i; 55 if(j < 0) { 56 j = 0; //如果i > extend[po] + po则要从头开始匹配 57 } 58 while(i + j < lens && j < lenp && str[j + i] == pattern[j]) { // 计算extend[i] 59 j++; 60 } 61 extend[i] = j; 62 po = i; // 更新po的位置 63 } 64 } 65 } 66 67 int main() { 68 int T; 69 scanf("%d", &T); 70 while (T--) { 71 scanf("%s%s", S1, S2); 72 int len1 = strlen(S1); 73 int len2 = strlen(S2); 74 std::reverse(S1, S1 + len1); 75 std::reverse(S2, S2 + len2); 76 extend_KMP(S1, len1, S2, len2); 77 memset(d, 0, sizeof(d)); 78 for (int i = 0; i < len1; i++) { 79 // printf("%d ", extend[i]); 80 d[extend[i]]++; 81 } 82 LL total = 0LL; 83 int ans = 0; 84 for (int j = len2; j > 0; j--) { 85 total = (total + d[j]) % MOD_NUM; 86 ans = (ans + total * j) % MOD_NUM; 87 } 88 printf("%d\n", ans); 89 // putchar('\n'); 90 } 91 return 0; 92 }
最后,在做此题的过程中我还突然发现输入外挂没用了,加上输入外挂后的执行时间比直接用scanf更长。也许是因为现在的oj用上了最新的编译器吧。