poj3415(后缀数组)
poj3415
题意
给定两个字符串,给出长度 \(m\) ,问这两个字符串有多少对长度大于等于 \(m\) 且完全相同的子串。
分析
首先连接两个字符串 A B,中间用一个特殊符号分割开。
按照 \(sa\) 的顺序(即枚举 \(height\) 值),进行分组,那么有公共前缀长大于等于 \(m\) 的都分到了一组,对于某一组,后缀串可能来自于 A 也可能来自于 B,那么对于 A 找前面的 B 串,对于 B 找前面的 A 串,如果某两个后缀串的公共前缀长为 \(l(l \geqslant m)\),那么显然会有 \(l - m + 1\) 对子串。
注意到这个性质: 对于两个后缀串 j 和 k,设 \(rnk[j] < rnk[k]\) ,LCP长度为 \(height[rnk[j]+1], height[rnk[j]+2], ... , height[rnk[k]]\) 中的最小值。
维护一个单调递增的栈(保证栈顶最大)可以用一个二维数组表示(\(q[][2]\)),一个是栈,一个是某个数的个数。
举个例子,如果连续的 \(height\) 值为 \(2 \ 3 \ 4\) ,\(m = 2\),前三个为 A 串,那么 \(2 \ 3 \ 4\) 全部入栈,且计算对答案的贡献 \(sum\)(不是直接加到答案上),即 \((2-2+1) + (3-2+1) + (4-2+1)\) ,到 B 串时,答案就加上了这个值,但是如果后面还有一个 B 串且 \(height\) 为 \(3\),那么就要弹栈,且减去 \(sum\) 值多的那部分(前面多算了),栈里 \(4\) 的数量为 \(1\),所以 \(sum = sum - (4 - 3) * 1\) ,且栈里 \(3\) 的数量变为了 \(2\) ( \(4\) 对应的 A 串对于后面串提供的贡献减小了(注意前面的性质),所以\(4\) 变为了 \(3\) ),答案加上 \(sum\)。
code
#include<cstdio>
#include<cstring>
#include<algorithm>
using namespace std;
typedef long long ll;
const int MAXN = 2e5 + 10;
const int INF = 1e9;
char s[MAXN];
int sa[MAXN], t[MAXN], t2[MAXN], c[MAXN], n; // n 为 字符串长度 + 1,即最后一位为数字 0
int rnk[MAXN], height[MAXN];
// 构造字符串 s 的后缀数组。每个字符值必须为 0 ~ m-1
void build_sa(int m) {
int i, *x = t, *y = t2;
for(i = 0; i < m; i++) c[i] = 0;
for(i = 0; i < n; i++) c[x[i] = s[i]]++;
for(i = 1; i < m; i++) c[i] += c[i - 1];
for(i = n - 1; i >= 0; i--) sa[--c[x[i]]] = i;
for(int k = 1; k <= n; k <<= 1) {
int p = 0;
for(i = n - k; i < n; i++) y[p++] = i;
for(i = 0; i < n; i++) if(sa[i] >= k) y[p++] = sa[i] - k;
for(i = 0; i < m; i++) c[i] = 0;
for(i = 0; i < n; i++) c[x[y[i]]]++;
for(i = 0; i < m; i++) c[i] += c[i - 1];
for(i = n - 1; i >= 0; i--) sa[--c[x[y[i]]]] = y[i];
swap(x, y);
p = 1;
x[sa[0]] = 0;
for(i = 1; i < n; i++)
x[sa[i]] = y[sa[i - 1]] == y[sa[i]] && y[sa[i - 1] + k] == y[sa[i] + k] ? p - 1 : p++;
if(p >= n) break;
m = p;
}
}
void getHeight() {
int i, j, k = 0;
for(i = 0; i < n; i++) rnk[sa[i]] = i;
for(i = 0; i < n - 1; i++) {
if(k) k--;
j = sa[rnk[i] - 1];
while(s[i + k] == s[j + k]) k++;
height[rnk[i]] = k;
}
}
char s2[MAXN];
int q[MAXN][2];
int main() {
int m;
while(~scanf("%d", &m) && m) {
scanf("%s%s", s, s2); // A 、B串
int l = strlen(s), l2 = strlen(s2);
s[l++] = '#';
for(int i = 0; i < l2; i++) s[i + l] = s2[i];
s[l + l2] = 0;
n = l + l2 + 1;
build_sa(128);
getHeight();
ll ans = 0, sum = 0;
int top = 0;
// 在 B 串前找 A
for(int i = 2; i < n; i++) {
int cnt = 0;
if(height[i] < m) {
top = 0; sum = 0;
continue;
}
if(sa[i - 1] < l) {
cnt++;
sum += height[i] - m + 1;
}
while(top && q[top - 1][0] >= height[i]) {
top--;
sum -= (q[top][0] - height[i]) * q[top][1];
cnt += q[top][1];
}
q[top][0] = height[i]; q[top++][1] = cnt;
if(sa[i] >= l) ans += sum;
}
// 在 A 串前找 B
sum = 0; top = 0;
for(int i = 2; i < n; i++) {
int cnt = 0;
if(height[i] < m) {
top = 0; sum = 0;
continue;
}
if(sa[i - 1] >= l) {
cnt++;
sum += height[i] - m + 1;
}
while(top && q[top - 1][0] >= height[i]) {
top--;
sum -= (q[top][0] - height[i]) * q[top][1];
cnt += q[top][1];
}
q[top][0] = height[i]; q[top++][1] = cnt;
if(sa[i] < l) ans += sum;
}
printf("%lld\n", ans);
}
return 0;
}