POJ 3415 Common Substrings

题目大意:

给两个长度不超过100000的字符串A和B, 求三元组(i, j, k)的数量, 即满足A串从i开始的后缀与B串从j开始的后缀有长度为k的公共前缀, 还要求k不小于某个给你的数K.

 

简要分析:

如果i后缀与j后缀的LCP长度为L, 在L不小于K的情况下, 它对答案的贡献为L - K + 1. 于是我们可以将两个串连起来, 中间加个奇葩的分隔符, 做一遍后缀数组, 并按height数组的值对后缀分组, 保证同组内的后缀间的LCP不小于K. 显然不同组间的答案是独立的, 我们可以单独处理每一组. 于是问题变成了: 在每一组内, 对每个A的后缀, 算出它之前B的后缀与之LCP的和(其实是LCP - K + 1, 后面都说成LCP), 再对每个B后缀, 算出它之前A的后缀与之LCP的和, 这里我们考虑前者. 利用height数组求LCP时是对一段区间求最小值, 所以某个A后缀前的B后缀与之的LCP是单调不减的, 于是我们维护一个单调递增的栈, 每个元素记录height值和该后缀的下标, 那么若这个元素与前一个元素下标之间有x个B后缀, 则它对答案的贡献就是x * height. 这样维护一遍可以做到O(N), N为A和B的总长度. 这道题感觉很巧妙啊!

 

代码实现:

View Code
 1 #include <cstdio>
2 #include <cstdlib>
3 #include <cstring>
4 #include <algorithm>
5 using namespace std;
6
7 typedef long long val_t;
8 const char split[] = "#";
9 const int MAX_L = 100000, MAX_N = MAX_L * 2 + 1;
10 char a[MAX_L + 2], b[MAX_L + 2], c[MAX_N + 2];
11 int m, sza, szb, n, sa[MAX_N + 1], rank[MAX_N + 1], height[MAX_N + 1], cnt[MAX_N + 1];
12 val_t ans;
13 int tcnt[MAX_N + 1];
14
15 struct node_t {
16 int v[2], p;
17 bool operator == (const node_t &t) const {
18 return v[0] == t.v[0] && v[1] == t.v[1];
19 }
20 } nd[MAX_N + 1], tp[MAX_N + 1];
21
22 struct mono_stack_t {
23 val_t v[MAX_N + 1][4];
24 int stop;
25 mono_stack_t() { stop = 0; }
26 void clear() { stop = 0; }
27 void push(int p, int x) {
28 while (stop && x <= v[stop][0]) stop --;
29 stop ++;
30 v[stop][0] = x, v[stop][1] = p;
31 v[stop][2] = v[stop - 1][2] + (val_t)(tcnt[p] - tcnt[v[stop - 1][1]]) * x;
32 v[stop][3] = v[stop - 1][3] + (val_t)(p - v[stop - 1][1]) * x;
33 }
34 val_t ask(bool t) {
35 if (t) return v[stop][2];
36 else return v[stop][3] - v[stop][2];
37 }
38 } st;
39
40 void ra(int b) {
41 for (int i = 1; i >= 0; i --) {
42 memset(cnt, 0, sizeof(int) * (b + 1));
43 for (int j = 1; j <= n; j ++) cnt[nd[j].v[i]] ++;
44 for (int j = 1; j <= b; j ++) cnt[j] += cnt[j - 1];
45 for (int j = n; j >= 1; j --) tp[cnt[nd[j].v[i]] --] = nd[j];
46 memcpy(nd, tp, sizeof(node_t) * (n + 1));
47 }
48 for (int i = 1, j = 1, k = 1; i <= n; i = j, k ++)
49 while (j <= n && nd[j] == nd[i]) rank[nd[j ++].p] = k;
50 }
51
52 int main() {
53 while (scanf("%d", &m) != EOF && m) {
54 scanf("%s%s", a + 1, b + 1);
55 sza = strlen(a + 1), szb = strlen(b + 1);
56 strcpy(c + 1, a + 1);
57 strcat(c + 1, split);
58 strcat(c + 1, b + 1);
59 n = strlen(c + 1);
60 for (int i = 1; i <= n; i ++) {
61 nd[i].v[0] = c[i], nd[i].v[1] = 0;
62 nd[i].p = i;
63 }
64 ra(255);
65 for (int s = 1; s < n; s <<= 1) {
66 for (int i = 1; i <= n; i ++) {
67 nd[i].v[0] = rank[i], nd[i].v[1] = i + s <= n ? rank[i + s] : 0;
68 nd[i].p = i;
69 }
70 ra(n);
71 }
72 for (int i = 1; i <= n; i ++) sa[rank[i]] = i;
73 for (int i = 1, j, k = 0; i <= n; height[rank[i ++]] = k)
74 for (k ? k -- : 0, j = sa[rank[i] - 1]; c[i + k] == c[j + k]; k ++);
75
76 ans = 0LL;
77 for (int i = 1; i <= n; ) {
78 int j = i;
79 while (j + 1 <= n && height[j + 1] >= m) j ++;
80
81 tcnt[0] = 0;
82 for (int k = i, p = 0; k <= j; k ++, p ++) tcnt[p + 1] = tcnt[p] + (sa[k] > sza + 1);
83 st.clear();
84 for (int k = i + 1, p = 0; k <= j; k ++) {
85 st.push(++ p, height[k] - m + 1);
86 if (sa[k] <= sza) ans += st.ask(1);
87 else ans += st.ask(0);
88 }
89
90 i = j + 1;
91 }
92
93 printf("%lld\n", ans);
94 }
95 return 0;
96 }
posted @ 2012-03-10 19:50  zcwwzdjn  阅读(1636)  评论(0编辑  收藏  举报