【笔记】后缀数组
参考资料:
题解 P3809 【【模板】后缀排序】 - xMinh 的博客 - 洛谷博客 (luogu.com.cn)
后缀数组
记号:
\(s_i\):字符串 \(s\) 的第 \(i\) 个后缀。
\(t_i\):排好序后第 \(i\) 个字符串。
\(LCP(s_i,s_j)\) 或 \(LCP(i,j)\):字符串 \(s\) 第 \(i\) 个和第 \(j\) 个后缀的 LCP 长度。
\(rk\):原数组到排好序的数组的映射。
\(sa\):排好序的后缀数组到原数组的映射。
\(height_i\):\(LCP(h_{i-1},h_i)\)
\(h_i\):\(height[sa[i]]\)
LCP lemma
首先证 \(RHS\ge LHS\),也就是相邻两个的 LCP 必然大于等于 \(s_i\) 和 \(s_j\) 的 LCP。
abcd
abcx
abcd
abcd
考虑上面的情形,abcx
与前后的 LCP 小于 4,于是它不可能在字典序中出现在该位置。
然后证存在 \(height_k=LCP(s_i,s_j)\),因为如果所有的 \(height_k\) 都大于 \(LCP(s_i,s_j)\),就会以一种传递性导致 \(LCP(s_i,s_j)\) 变大。
abcde
abcde
abcde
...
abcde
「相似度」引理
同理,有
abcdex
abcxxx
abcdef
反证法,如果存在 \(k<j\) 使得 \(LCP(s[rk[k]],s[rk[i]])>LCP(rk[j],rk[i])\),那么它必然在字典序中 \(i\) 和 \(j\) 的中间。
\(h[i]\ge h[i-1]-1\)
![[files/Pasted image 20211209143817.png]]
考虑第 \(i-1\) 个后缀在排好序数组中前一个字符串 \(k\),有
总之有 \(LCP(i,k+1)\ge h_{i-1}-1\). 考虑到 \(k\) 在字典序汇总排在 \(i-1\) 的前面,而 \(i\) 和 \(k+1\) 由 \(i-1\) 和 \(k\) 去掉首字符得到,所以 \(k+1\) 在字典序中也排在 \(i\) 的前面。由「相似度」引理知 \(h_i=LCP(i,sa[rk[i]-1]])\ge LCP(i,k+1)\ge h_{i-1}-1\).
于是按照 \(i=rk[1],rk[2],\dots,rk[n]\) 的顺序计算 \(h_i\),便可做到 \(O(n)\) 复杂度求 height.
代码
#include <bits/stdc++.h>
using namespace std;
const int N = 1e6 + 5;
int n, m, cnt, t[N], sa[N], rk[2 * N], l[2 * N];
char c[N];
void buildSA() {
m = (int)'z';
for(int i = 1; i <= n; i++) ++t[rk[i] = c[i]];
for(int i = 1; i <= m; i++) t[i] += t[i - 1];
for(int i = n; i >= 1; i--) sa[t[rk[i]]--] = i;
for(int k = 1; k < n; k <<= 1) {
// 按第二关键字排序
for(int i = n - k + 1; i <= n; i++) l[cnt = i - n + k] = i;
for(int i = 1; i <= n; i++)
if(sa[i] > k) l[++cnt] = sa[i] - k;
// 按第一关键字排序
for(int i = 1; i <= m; i++) t[i] = 0;
for(int i = 1; i <= n; i++) ++t[rk[i]];
for(int i = 1; i <= m; i++) t[i] += t[i - 1];
for(int i = n; i >= 1; i--) sa[t[rk[l[i]]]--] = l[i], l[i] = 0;
// 处理新的 rk
swap(rk, l), cnt = 0;
for(int i = 1; i <= n; i++)
if(l[sa[i]] == l[sa[i - 1]] && l[sa[i] + k] == l[sa[i - 1] + k]) rk[sa[i]] = cnt;
else rk[sa[i]] = ++cnt;
if((m = cnt) == n) break;
}
// 处理最终的 rk
for(int i = 1; i <= n; i++) rk[sa[i]] = i;
}
void getHeight() {
for(int i = 1, k = 0; i <= n; i++) {
if(rk[i] == 1) continue;
if(k) --k;
for(int j = sa[rk[i] - 1]; j + k <= n && i + k <= n && c[i + k] == c[j + k]; ++k);
height[rk[i]] = k;
}
}
int main() {
ios::sync_with_stdio(false); cin.tie(nullptr);
scanf("%s", c + 1);
n = strlen(c + 1);
buildSA();
getHeight();
for(int i = 1; i <= n; i++) cout << sa[i] << ' ';
return 0;
}
例题
[[【题解】P2408 不同子串个数]]
[[【题解】P3975 弦论(TJOI2015)]]
P4094 字符串
佳媛姐姐过生日的时候,她的小伙伴从某东上买了一个生日礼物。生日礼物放在一个神奇的箱子中。箱子外边写了一个长为 \(n\) 的字符串 \(s\),和 \(m\) 个问题。佳媛姐姐必须正确回答这 \(m\) 个问题,才能打开箱子拿到礼物,升职加薪,出任 CEO,嫁给高富帅,走上人生巅峰。
每个问题均有 \(a,b,c,d\) 四个参数,问你子串 \(s[a..b]\) 的所有子串和 \(s[c..d]\) 的最长公共前缀的长度的最大值是多少?佳媛姐姐并不擅长做这样的问题,所以她向你求助,你该如何帮助她呢?
求最大值考虑二分答案。
判定存在一个长度为 \(k\) 的公共前缀,只需满足有 \(a\sim b-k+1\) 之间的后缀与 \(s[c..d]\) 的 LCP 长度大于等于 \(k\).
满足 LCP 长度条件的后缀可以在 SA 数组上二分找出来,得到一个区间。然后在这个区间里查是否有后缀的下标在 \(a\sim b-k+1\) 里面,用主席树解决即可。
二分找区间的操作需要取 height 的 \(\min\),并且要 \(O(1)\) 实现,得再写个 ST 表。
#include <bits/stdc++.h>
using namespace std;
// N2 为主席树节点数
const int N = 1e5 + 5, N2 = N * 20;
int n, m;
char c[N];
struct suffixArray {
int sa[N], t[N], ht[N], l[N * 2], rk[N * 2], cnt, m;
void build() {
// buildSA
m = int('z');
for(int i = 1; i <= n; i++) ++t[rk[i] = c[i]];
for(int i = 1; i <= m; i++) t[i] += t[i - 1];
for(int i = n; i >= 1; i--) sa[t[rk[i]]--] = i;
for(int k = 1; k < n; k <<= 1) {
for(int i = n - k + 1; i <= n; i++) l[cnt = i - n + k] = i; // i 不是 sa[i]
for(int i = 1; i <= n; i++) if(sa[i] > k) l[++cnt] = sa[i] - k;
for(int i = 1; i <= m; i++) t[i] = 0;
for(int i = 1; i <= n; i++) ++t[rk[l[i]]];
for(int i = 1; i <= m; i++) t[i] += t[i - 1];
for(int i = n; i >= 1; i--) sa[t[rk[l[i]]]--] = l[i], l[i] = 0;
swap(rk, l), cnt = 0;
for(int i = 1; i <= n; i++) {
// 注意是 rk[sa[i]]
if(l[sa[i - 1]] == l[sa[i]] && l[sa[i - 1] + k] == l[sa[i] + k]) rk[sa[i]] = cnt;
else rk[sa[i]] = ++cnt;
}
if((m = cnt) == n) break;
}
// getHeight
for(int i = 1, j, k = 0; i <= n; i++) {
if(rk[i] == 1) continue;
k -= (k > 0), j = sa[rk[i] - 1];
while(i + k <= n && j + k <= n && c[i + k] == c[j + k]) ++k;
ht[rk[i]] = k;
}
}
} sa;
struct segmentTree {
int rt[N], sum[N2], lc[N2], rc[N2], cnt;
void copy(int a, int b) { sum[b] = sum[a], lc[b] = lc[a], rc[b] = rc[a]; }
int modify(int p, int l, int r, int x) {
copy(p, ++cnt), p = cnt;
if(l == r) return ++sum[p], p;
int mid = (l + r) >> 1;
if(x <= mid) lc[p] = modify(lc[p], l, mid, x);
else rc[p] = modify(rc[p], mid + 1, r, x);
sum[p] = sum[lc[p]] + sum[rc[p]];
return p;
}
int query(int pl, int pr, int l, int r, int x, int y) {
if(l >= x && r <= y) return sum[pr] - sum[pl];
int mid = (l + r) >> 1, ret = 0;
if(x <= mid) ret += query(lc[pl], lc[pr], l, mid, x, y);
if(y > mid) ret += query(rc[pl], rc[pr], mid + 1, r, x, y);
return ret;
}
void build() {
for(int i = 1; i <= n; i++)
rt[i] = modify(rt[i - 1], 1, n, sa.sa[i]);
}
} sg;
struct sparseTable {
int lg2[N], t[18][N];
// height 的 rmq 的处理方式需要注意
// 我这里直接处理成全部是闭区间的形式
// 但是需要特殊化处理 t[1][],并且拼接的时候要多拼一点才是对的。。。这里调了好久
void build() {
for(int i = 2; i <= n; i <<= 1) ++lg2[i];
for(int i = 3; i <= n; i++) lg2[i] += lg2[i - 1];
for(int i = 1; i <= n; i++)
t[0][i] = n - sa.sa[i] + 1, t[1][i] = sa.ht[i + 1];
for(int i = 2; (1 << i) <= n; i++)
for(int j = 1; j + (1 << i) - 1 <= n; j++)
t[i][j] = min(t[i - 1][j], min(t[i - 1][j + (1 << (i - 1))], sa.ht[j + (1 << (i - 1))]));
}
int query(int l, int r) {
int k = lg2[r - l + 1];
return min(t[k][l], t[k][r - (1 << k) + 1]);
}
} st;
bool check(int a, int b, int c, int d, int k) {
int l = 1, r = sa.rk[c], mid, ans = -1, ql, qr;
while(l <= r) {
mid = (l + r) >> 1;
if(st.query(mid, sa.rk[c]) >= k) ans = mid, r = mid - 1;
else l = mid + 1;
}
ql = ans, l = sa.rk[c], r = n;
while(l <= r) {
mid = (l + r) >> 1;
if(st.query(sa.rk[c], mid) >= k) ans = mid, l = mid + 1;
else r = mid - 1;
}
qr = ans;
return (sg.query(sg.rt[ql - 1], sg.rt[qr], 1, n, a, b - k + 1) > 0);
}
int main() {
scanf("%d%d%s", &n, &m, c + 1);
sa.build(), sg.build(), st.build();
for(int i = 1, a, b, c, d; i <= m; i++) {
scanf("%d%d%d%d", &a, &b, &c, &d);
int l = 1, r = min(b - a + 1, d - c + 1), mid, ans = -1;
while(l <= r) {
mid = (l + r) >> 1;
if(check(a, b, c, d, mid)) ans = mid, l = mid + 1;
else r = mid - 1;
}
cout << (ans == -1 ? 0 : ans) << '\n';
}
return 0;
}