SPOJ 1812
题目概述
求\(n\)个字符串的最长公共子串。\(n\le 10\)。
题解
对于最短字符串建立SAM。
考虑将其他字符串放到这个SAM上匹配,并记录mx[i]表示以i为结尾的字符串最长匹配长度。
从根开始向后匹配。如果自动机上可以向后匹配字符,就走一条转一边,转移到下一个结点;否则跳parent树,知道存在这个字符的转移边为止。
如果永远没有,那么从根开始重新匹配。每匹配到一个节点,就更新mx[i]。
匹配完再对于每一个mx[fa[i]]从mx[i]更新。这是因为一个节点的par必然在他的儿子中出现,但却有可能没有匹配到par这个节点。
容易证明这样每个结点都有最长匹配值。
再记录mn[i]表示i结点在每一个字符串中的最长匹配长度的最小值。
那么所有结点的mn[i]就是答案。
注:之所以选取最短字符串建立SAM是因为每次匹配后都要遍历整个SAM并更新SAM每个节点的mx值,如果不选取最短串有可能复杂度并不是\(\mathcal O(\sum s_i)\)的
代码
#include <cstdio>
#include <cstring>
#include <iostream>
using namespace std;
const int N = 1e5 + 5, NODE = 2 * N;
int n, len[11], mnId, mx[NODE];
char str[11][N];
struct SAM {
int ch[NODE][26], len[NODE], fa[NODE], mx[NODE], las, t, buc[N], od[NODE], mn[NODE];
SAM() { las = t = 1; }
void ins(int c) {
int p = las, np;
fa[las = np = ++t] = 1;
len[t] = len[p] + 1;
for (; p && !ch[p][c]; p = fa[p])
ch[p][c] = np;
if (p) {
int q = ch[p][c], nq;
if (len[p] + 1 == len[q])
fa[np] = q;
else {
fa[nq = ++t] = fa[q];
len[t] = len[p] + 1;
memcpy(ch[nq], ch[q], sizeof(ch[nq]));
for (fa[np] = fa[q] = nq; ch[p][c] == q; p = fa[p])
ch[p][c] = nq;
}
}
}
void init(int le) {
for (int i = 1; i <= t; ++i)
++buc[len[i]];
for (int i = 1; i <= le; ++i) //这里 n并不是字符串长度...
buc[i] += buc[i - 1];
for (int i = 1; i <= t; ++i)
od[buc[len[i]]--] = i;
for (int i = 1; i <= t; ++i)
mn[i] = 0x3f3f3f3f;
}
void calc(char *s, int n) {
int now = 1, l = 0;
for (int i = 1; i <= t; ++i)
mx[i] = 0;
for (int i = 1; i <= n; ++i) {
while (now && !ch[now][s[i] - 'a'])
now = fa[now], l = len[now];
if (now) {
now = ch[now][s[i] - 'a'];
mx[now] = max(mx[now], ++l);
} else {
now = 1;
l = 0;
}
}
for (int i = 1; i <= t; ++i) {
mx[fa[od[i]]] = min(max(mx[fa[od[i]]], mx[od[i]]), len[fa[od[i]]]);
mn[od[i]] = min(mn[od[i]], mx[od[i]]);
}
}
} sam;
int main() {
mnId = 1;
while (~scanf("%s", str[++n] + 1)) {
len[n] = strlen(str[n] + 1);
if (len[mnId] > len[n])
mnId = n;
}
--n;
for (int i = 1; i <= len[mnId]; ++i)
sam.ins(str[mnId][i] - 'a');
sam.init(len[mnId]);
for (int i = 1; i <= n; ++i)
if (i != mnId)
sam.calc(str[i], len[i]);
int ans = 0;
for (int i = 1; i <= sam.t; ++i)
ans = max(ans, sam.mn[i]);
printf("%d\n", ans);
return 0;
}