hdu6153后缀数组或扩展KMP

前两天刷了几题leetcode,感觉挺简单,于是又想刷刷hduoj了。随便打开没做过的一页,找了一题通过人数最多的,就是这道6153.

①.看完题没想太多,觉得应该是后缀数组(多年没刷题的我字符串这一块对后缀数组记忆最深吧),因为S1和S2长度都一百万,n^2受不了。nlogn应该行。
②.用后缀数组的话,需要会用后缀数组求子串出现次数。如果是任意子串,还不太好办,但是这里的子串只会是后缀,那就好办了——只需看此后缀在所有后缀中的前缀的数量,也即与本后缀的LCP等于len(s)的后缀数量。因为后缀数组是排过序的,只要往后看,直到height值小于len(s)为止。
③.会了②,就好办了。把S1和S2连起来(中间插入一个比如~号,记为S),求出S的后缀数组SA;又对S2求出后缀数组SA2。
对于每一个S2的后缀,用在SA中求出的出现次数减去SA2中求出的出现次数,就是此后缀在S1中的出现次数(因为它不会在跨S1和S2时出现)。
④.复杂度分析。构造后缀数组最优算法是O(n),用倍增算法构造后缀数组是O(nlogn)。后面的计算步骤,因为S2长度n为一百万,要计算n次,每次用②中的办法往后找,最坏情况要O(n)。但是不意味着这里的总复杂度只能O(n^2)。因为后缀数组是排好序的,按后缀数组的逆序处理这n次查找,那么第k次查找可以充分利用第k-1次查找的结果,往后滑动。也就是说这n次查找基本不会重叠,所以最后查找计算部分的总复杂度仍为O(n)。
这个地方我举个例子,以字符串aabaaaab~aabaaaab为例,它有17个后缀,在后缀数组中的顺序为:
aaaab
aaaab~aabaaaab
aaab
aaab~aabaaaab
aab
aabaaaab
aabaaaab~aabaaaab
aab~aabaaaab
ab
abaaaab
abaaaab~aabaaaab
ab~aabaaaab
b
baaaab
baaaab~aabaaaab
b~aabaaaab
~aabaaaab
求出后缀baaaab的出现次数为2以后,再求后缀b的出现次数时,因为知道b是baaaab的子串(height值等于b的长度),可以直接滑到b~aabaaaab进行判断。
那最后整个算法的复杂度还是卡在构造后缀数组部分,如果用倍增,那整个算法的复杂度为O(nlogn)
⑤.具体实现上,因为我要复用两次后缀数组的代码,所以每次预处理出data数组,再处理出d数组。
d数组就是逆序存的每个后缀在整个串中出现的次数。
代码如下:

  1 /*
  2  * Author    : ben
  3  */
  4 #include <cstdio>
  5 #include <cstdlib>
  6 #include <cstring>
  7 typedef long long LL;
  8 const int MAXN = 2010000;
  9 char s[MAXN];
 10 int sa[MAXN], height[MAXN], rank[MAXN], N;
 11 int tmp[MAXN], top[MAXN];
 12 void makesa() {
 13     int i, j, len, na;
 14     na = (N < 256 ? 256 : N);
 15     memset(top, 0, na * sizeof(int));
 16     for (i = 0; i < N; i++) {
 17         top[rank[i] = s[i] & 0xff]++;
 18     }
 19     for (i = 1; i < na; i++) {
 20         top[i] += top[i - 1];
 21     }
 22     for (i = 0; i < N; i++) {
 23         sa[--top[rank[i]]] = i;
 24     }
 25     for (len = 1; len < N; len <<= 1) {
 26         for (i = 0; i < N; i++) {
 27             j = sa[i] - len;
 28             if (j < 0) {
 29                 j += N;
 30             }
 31             tmp[top[rank[j]]++] = j;
 32         }
 33         sa[tmp[top[0] = 0]] = j = 0;
 34         for (i = 1; i < N; i++) {
 35             if (rank[tmp[i]] != rank[tmp[i - 1]]
 36                     || rank[tmp[i] + len] != rank[tmp[i - 1] + len]) {
 37                 top[++j] = i;
 38             }
 39             sa[tmp[i]] = j;
 40         }
 41         memcpy(rank, sa, N * sizeof(int));
 42         memcpy(sa, tmp, N * sizeof(int));
 43         if (j >= N - 1) {
 44             break;
 45         }
 46     }
 47 }
 48 
 49 void lcp() {
 50     int i, j, k;
 51     for (j = rank[height[i = k = 0] = 0]; i < N - 1; i++, k++) {
 52         while (k >= 0 && s[i] != s[sa[j - 1] + k]) {
 53             height[j] = (k--), j = rank[sa[j] + 1];
 54         }
 55     }
 56 }
 57 
 58 char S1[MAXN], S2[MAXN];
 59 int data1[MAXN], data2[MAXN];
 60 int d[MAXN];
 61 
 62 void makedata(int *data) {
 63     data[0] = 1;
 64     for (int i = N - 2; i > 0; i--) {
 65         int leni = N - 1 - sa[i];
 66         int j = N - i - 1;
 67         if (height[i + 1] < leni) {
 68             data[j] = 1;
 69         } else {
 70             int k = i + data[j - 1] + 1;
 71             while (k < N && height[k] >= leni) {
 72                 k++;
 73             }
 74             data[j] = k - i;
 75         }
 76 //        cout << data[j] << endl;
 77     }
 78 }
 79 
 80 const LL MOD_NUM = 1000000007LL;
 81 int work(int lens1, int lens2) {
 82     int ans = 0;
 83     strcpy(s, S2);
 84     N = lens2 + 1;
 85     makesa();
 86     lcp();
 87     makedata(data1);
 88     for (int i = 0; i < lens2; i++) {
 89         int j = lens2 - i - 1;
 90         d[i] = data1[lens2 - rank[j]];
 91 //        cout << d[i] << endl;
 92     }
 93     strcpy(s, S1);
 94     s[lens1] = '#';
 95     s[lens1 + 1] = 0;
 96     strcat(s, S2);
 97     N = lens1 + lens2 + 2;
 98     makesa();
 99     lcp();
100     makedata(data2);
101     for (int i = 0; i < lens2; i++) {
102         int j = N - i - 2;
103         d[i] = data2[N - 1 - rank[j]] - d[i];
104         ans = ans % MOD_NUM + ((i + 1LL) * d[i]) % MOD_NUM;
105 //        cout << d[i] << endl;
106     }
107     return ans % MOD_NUM;
108 }
109 
110 int main() {
111     int T;
112     scanf("%d", &T);
113     while (T--) {
114         scanf("%s%s", S1, S2);
115         int lens1 = strlen(S1);
116         int lens2 = strlen(S2);
117         printf("%d\n", work(lens1, lens2));
118     }
119     return 0;
120 }

 

然而,提交上去,超时了看来数据量很强。是卡在后缀数组构造的倍增算法上了但是我手头没有DC3等更优算法的模板。

转念一想,这题通过的人这么多,不可能需要高阶后缀数组算法的。于是回忆了一下还有别的什么算法。想到了扩展KMP。

用扩展KMP解此题的思路主要是要把S1和S2逆序,逆序以后,题目要求的S2的每个后缀就变成前缀了。根据extend数组的定义我们知道,如果extend[i] = x,则表示S2的前x个字符与S1从i开始的x个字符相同,统计extend数组中有多少个x就知道S2的这个前缀在S1中出现的次数。这种统计是可以在线性时间完成的,而前面生成extend数组的时间也为线性,故最后整体复杂度也为线性O(N)。代码如下

 1 /*
 2  * Author    : ben
 3  */
 4 #include <cstdio>
 5 #include <cstdlib>
 6 #include <cstring>
 7 #include <cmath>
 8 #include <ctime>
 9 #include <algorithm>
10 typedef long long LL;
11 const int MAXN = 1001000;
12 char S1[MAXN], S2[MAXN];
13 int d[MAXN];
14 const LL MOD_NUM = 1000000007LL;
15 int next[MAXN], extend[MAXN];
16 void get_next(const char *str, int len){
17     // 计算next[0]和next[1]
18     next[0] = len;
19     int i = 0;
20     while(str[i] == str[i + 1] && i + 1 < len) {
21         i++;
22     }
23     next[1] = i;
24     int po = 1; //初始化po的位置
25     for(i = 2; i < len; i++) {
26         if(next[i - po] + i < next[po] + po) { //第一种情况,可以直接得到next[i]的值
27             next[i] = next[i - po];
28         } else { //第二种情况,要继续匹配才能得到next[i]的值
29             int j = next[po] + po - i;
30             if(j < 0) {
31                 j = 0; //如果i > po + next[po],则要从头开始匹配
32             }
33             while(i + j < len && str[j] == str[j + i]) { //计算next[i]
34                 j++;
35             }
36             next[i] = j;
37             po = i; //更新po的位置
38         }
39     }
40 }
41 void extend_KMP(const char *str, int lens, const char *pattern, int lenp) {
42     get_next(pattern, lenp); // 先计算模式串的next数组
43     // 计算extend[0]
44     int i = 0;
45     while(str[i] == pattern[i] && i < lenp && i < lens) {
46         i++;
47     }
48     extend[0] = i;
49     int po = 0; // 初始化po的位置
50     for(i = 1; i < lens; i++) {
51         if(next[i - po] + i < extend[po] + po) { //第一种情况,直接可以得到extend[i]的值
52             extend[i] = next[i - po];
53         } else { // 第二种情况,要继续匹配才能得到extend[i]的值
54             int j = extend[po] + po - i;
55             if(j < 0) {
56                 j = 0; //如果i > extend[po] + po则要从头开始匹配
57             }
58             while(i + j < lens && j < lenp && str[j + i] == pattern[j]) { // 计算extend[i]
59                 j++;
60             }
61             extend[i] = j;
62             po = i; // 更新po的位置
63         }
64     }
65 }
66 
67 int main() {
68     int T;
69     scanf("%d", &T);
70     while (T--) {
71         scanf("%s%s", S1, S2);
72         int len1 = strlen(S1);
73         int len2 = strlen(S2);
74         std::reverse(S1, S1 + len1);
75         std::reverse(S2, S2 + len2);
76         extend_KMP(S1, len1, S2, len2);
77         memset(d, 0, sizeof(d));
78         for (int i = 0; i < len1; i++) {
79 //            printf("%d ", extend[i]);
80             d[extend[i]]++;
81         }
82         LL total = 0LL;
83         int ans = 0;
84         for (int j = len2; j > 0; j--) {
85             total = (total + d[j]) % MOD_NUM;
86             ans = (ans + total * j) % MOD_NUM;
87         }
88         printf("%d\n", ans);
89 //        putchar('\n');
90     }
91     return 0;
92 }

最后,在做此题的过程中我还突然发现输入外挂没用了,加上输入外挂后的执行时间比直接用scanf更长。也许是因为现在的oj用上了最新的编译器吧。

 

posted @ 2022-12-06 08:05  moonbay  阅读(44)  评论(0编辑  收藏  举报