后缀数组

后缀数组相关

后缀数组

基础的数组有两个:\(\text{sa}[]\)\(\text{rk}[]\)\(\text{sa}[i]\) 表示“第 \(i\) 小的后缀是第几个后缀”,\(\text{rk}[i]\) 表示“第 \(i\) 个后缀是第几小的”。这两个数组可以通过排序产生。

考虑双关键字倍增排序。初始化把所有长度为 \(1\) 的子串排个序,得到一个 \(\text{sa}[]\)\(\text{rk}_ 1[]\)。然后用 \(\text{rk}_ 1[i]\)\(\text{rk}_ 1[i+1]\) 作为当前 \(i\) 的关键字对长度为 \(2\) 的子串排序,得到一个 \(\text{sa}[]\)\(\text{rk}_ 2[]\)。然后用 \(\text{rk}_ 2[i]\)\(\text{rk}_ 2[i+2]\) 作为当前 \(i\) 的关键字对长度为 \(4\) 的子串排序,得到一个 \(\text{sa}[]\)\(\text{rk}_ 4[]\)……。倍增到最后得到一个后缀数组 \(\text{sa}[]\)\(\text{rk}_ n[]\)

使用基数排序,每轮排序的复杂度是 \(O(n)\) 的。由于进行了 \(O(\log n)\) 次,因此复杂度是 \(O(n \log n)\) 的。

Submission loj的设计就很优美

code
const int N = 1e6 + 10;
char ch[N];
int n, sa[N], rk[N];
void getsa(char ch[], int sa[], int rk[]) {
    int m = 127;
    static int cnt[N], rk2[N], key[N], id[N];
    memset(cnt+1, 0, m << 2);
    rep(i,1,n) rk[i] = ch[i], cnt[rk[i]]++;
    rep(i,1,m) cnt[i] += cnt[i-1];
    pre(i,n,1) sa[cnt[rk[i]] --] = i;
    for (int w = 1; ; w <<= 1) {
        int p = 0;
        for (register int i = n; i > n - w; -- i) id[++p] = i;
        rep(i,1,n) if (sa[i] > w) id[++p] = sa[i] - w;
        memset(cnt + 1, 0, m << 2);
        rep(i,1,n) key[i] = rk[id[i]], cnt[key[i]]++;
        rep(i,1,m) cnt[i] += cnt[i-1];
        pre(i,n,1) sa[cnt[key[i]] --] = id[i];
        memcpy(rk2+1, rk+1, n << 2);
        p = 0;
        rep(i,1,n) rk[sa[i]] = (
                        rk2[sa[i]] == rk2[sa[i-1]] and 
                        rk2[sa[i] + w] == rk2[sa[i-1] + w]
                     ) ? p : ++p;
        if (p == n) { rep(i,1,n) sa[rk[i]] = i; break; }
        m = p;
    }
}

当然你可以 \(O(n)\) 排序。有两种方法:DC3 和 SA-IS。DC3 难写常数还大,所以基本上没人用。这里介绍 SA-IS。
IS 是诱导排序,因为 cache 很友好所以常数小。而且码量不大

下面称从 \(i\) 位置开始的后缀为 后缀 \(i\)

我们在字符串最后加上去一个字典序小于字符串内任意字符的新字符 \(\#\)。将后缀分为两类:S 类和 L 类。有:

  1. 新加入的字符为 S 类。
  2. 若从 \(i\) 位置开始的后缀字典序严格小于从 \(i+1\) 开始的后缀,则从 \(i\) 位置开始的后缀为 S 类,反之为 L 类。

后缀的类型可以倒着扫一遍得到。具体地,若当前字符与后一位字符不同,则该后缀的类型可得;若相同则该后缀的类型与后一个后缀相同。

若后缀 \(i\) 为 S 类,后缀 \(j\) 为 L 类且字符 \(i =\) 字符 \(j\),则后缀 \(i\) 的字典序大于后缀 \(j\)

然后排序。我们定义 lms(Left-Most-S-suffix) 串为左边是一个 L 型串的 S 型串。显然有 lms 串的数量最多是 \(\frac n2\) 的。定义 lms 子串为两个 lms 串中间(包含两端)的串。SA-IS 通过把原串拆成 lms 子串的方式缩小求解范围。

我们递归求得下层的 \(\text{sa}[]\),现在需要求得本层的 \(\text{sa}[]\)。使用诱导排序。步骤:

  1. 使用两部分桶记录 S 串和 L 串的数量
  2. 倒序扫 lms,放进它的 sa 对应的 S 型桶里
  3. 正序扫 sa,如果 sa - 1 位置的类型是 L 型就放进对应的 L 型桶里
  4. 倒序扫 sa,如果 sa - 1 位置的类型是 S 型就放进对应的 S 型桶里

有结论说明 lms 子串进行一遍诱导排序后就有序了。
然后考虑如何对 lms 离散化。如果首位不同则编号,相同则暴力匹配。

  1. 确定每个后缀的类型,找出lms类型后缀。
  2. 第一遍诱导排序,对lms排序。
  3. 对lms离散化。
  4. 如果lms互不相同,直接计算sa,否则递归。
  5. 恢复lms的位置,诱导排序一遍计算sa。

这里是一份实现:

code
template <typename _Tp = char, const int siz = (1 << 8 << sizeof(_Tp)) - 1>
struct SA {
	_Tp ch[N];
	int n = -1, sa[N], rk[N], heg[N];
	int a[N], tr[siz + 5], cur[N];
	int sum[N], bse[N << 3], *_t = bse;

	_Tp* begin() { return ch + 1; }
	const _Tp operator [] (const int & p) const { return ch[p]; }
	_Tp & operator [] (const int & p) { return ch[p]; }

	#define _MemAlloc(pool, size, tag) (tag = pool, pool += size)
	#define pushs(x) (sa[cur[a[x]] --] = x)
	#define pushi(x) (sa[cur[a[x]] ++] = x)
	#define inds(lms)  		\
		rep(i,1,n) sa[i] = -1, sum[i] = 0; \
		rep(i,1,n) sum[a[i]] ++; \
		rep(i,1,n) sum[i] += sum[i - 1]; \
		rep(i,1,n) cur[i] = sum[i]; \
		pre(i,m,1) pushs(lms[i]); \
		rep(i,1,n) cur[i] = sum[i - 1] + 1; \
		rep(i,1,n) if (sa[i] > 1 and !tp[sa[i] - 1]) pushi(sa[i] - 1); \
		rep(i,1,n) cur[i] = sum[i]; \
		pre(i,n,1) if (sa[i] > 1 and tp[sa[i] - 1]) pushs(sa[i] - 1);

	inline void SA_IS(int n, int* a) {
		int* tp; _MemAlloc(_t, n + 1, tp); tp[n] = 1;
		int* p; _MemAlloc(_t, n + 2, p);
		pre(i,n-1,1) tp[i] = (a[i] == a[i + 1]) ? tp[i + 1] : (a[i] < a[i + 1]);
		int m = 0, tot = 0;
		rep(i,1,n) rk[i] = (tp[i] and !tp[i - 1]) ? (p[++ m] = i, m) : -1;
		inds(p);
		int* a1; _MemAlloc(_t, m + 1, a1);
		p[m + 1] = n;
		for (int i = 1, x, y; i <= n; ++ i) if ((x = rk[sa[i]]) != -1) {
			if (tot == 0 or p[x + 1] - p[x] != p[y + 1] - p[y]) ++ tot;
			else for (int p1 = p[x], p2 = p[y]; p2 <= p[y + 1]; ++ p1, ++ p2)
				if ((a[p1] << 1 | tp[p1]) != (a[p2] << 1 | tp[p2])) { ++ tot; break; }
			a1[y = x] = tot;
		}
		if (tot == m) rep(i,1,m) sa[a1[i]] = i;
		else SA_IS(m, a1);
		rep(i,1,m) a1[i] = p[sa[i]];
		inds(a1);
	}

	int st[N][20], lgv[N];
	inline void build() {
		if (is_same<_Tp, char>::value) n = strlen(begin());

		memset(tr, 0, sizeof tr);
		rep(i,1,n) tr[ch[i]] = 1;
		rep(i,1,siz + 1) tr[i] += tr[i - 1];
		rep(i,1,n) a[i] = tr[ch[i]] + 1;
		a[n + 1] = 1; _t = bse;
		SA_IS(n + 1, a); 
		rep(i,1,n) sa[i] = sa[i + 1];
		rep(i,1,n) rk[sa[i]] = i;
		
		for (int i = 1, k = 0; i <= n; ++ i) {
			if (rk[i] == 0) continue;
			if (k) -- k;
			while (ch[i + k] == ch[sa[rk[i] - 1] + k]) ++ k;
			heg[rk[i]] = k;
		}

		rep(i,1,n) st[i][0] = heg[i];
		rep(i,2,n) lgv[i] = lgv[i >> 1] + 1;
		rep(i,1,lgv[n]) for (int j = 1; j + (1 << i) - 1 <= n; ++ j) 
			st[j][i] = min(st[j][i - 1], st[j + (1 << i - 1)][i - 1]);
	}

	inline int lcp(int l, int r) {
		if (l <= 0 or l > n or r <= 0 or r > n) return 0;
		l = rk[l], r = rk[r];
		if(l > r) swap(l, r); l++; 
		int k = lgv[r - l + 1]; 
		return min(st[l][k], st[r - (1 << k) + 1][k]); 
	}
}; 

height

\(height[i] = \text{lcp}(\text{sa}[i],\text{sa}[i-1])\)

引理:\(height[\text{rk}[i]] \ge height[\text{rk}[i-1]]-1\)
然后我们就可以顺着 \(\text{rk}\) 数组扫一遍得到答案。

void getheight() {
    for (int i = 1, k = 0; i <= n; i++) {
        if (rk[i] == 0) continue;
        if (k) k--;
        while (s[i + k] == s[sa[rk[i] - 1] + k]) k++;
        height[rk[i]] = k;
    }
}

定理:\(\text{lcp}(sa[i],sa[j]) = \min_{k=i}^j height[k]\)

所以可以在 height 数组上建一个st表得到lcp。

然后放一份正常的 \(O(n\log n)\) 代码:

code
template <typename _Tp = char, const int siz = 127>
struct SA {
	_Tp ch[N];
	int n = -1, sa[N], rk[N], heg[N];
	int cnt[N], rk2[N], key[N], id[N];

	_Tp* begin() { return ch + 1; }
	const _Tp operator [] (const int & p) const { return ch[p]; }
	_Tp & operator [] (const int & p) { return ch[p]; }

	void getsa() {
		int m = siz;
		memset(cnt + 1, 0, sizeof(*cnt) * m);
		rep(i,1,n) rk[i] = ch[i], cnt[rk[i]]++;
		rep(i,1,m) cnt[i] += cnt[i - 1];
		pre(i,n,1) sa[cnt[rk[i]] --] = i;
		for (int w = 1, p; ; w <<= 1) {
			p = 0;
			pre(i,n,n-w+1) id[++ p] = i;
			rep(i,1,n) if (sa[i] > w) id[++ p] = sa[i] - w;
			memset(cnt + 1, 0, sizeof(*cnt) * m);
			rep(i,1,n) key[i] = rk[id[i]], cnt[key[i]] ++;
			rep(i,1,m) cnt[i] += cnt[i - 1];
			pre(i,n,1) sa[cnt[key[i]] --] = id[i];
			memcpy(rk2 + 1, rk + 1, sizeof(*rk2) * n);
			p = 0;
			rep(i,1,n) rk[sa[i]] = (
					rk2[sa[i]] == rk2[sa[i - 1]] and 
					rk2[sa[i] + w] == rk2[sa[i - 1] + w]
					) ? p : ++ p;
			if (p == n) { rep(i,1,n) sa[rk[i]] = i; break; }
			m = p;
		}
	}

	void getheg() {
		for (int i = 1, k = 0; i <= n; ++ i) {
			if (rk[i] == 0) continue;
			if (k) -- k;
			while (ch[i + k] == ch[sa[rk[i] - 1] + k]) ++ k;
			heg[rk[i]] = k;
		}
	}

	int st[N][20], lgv[N];
	void build() {
		if (n <= 0) n = strlen(begin());
		getsa(); getheg();
		rep(i,1,n) st[i][0] = heg[i];
		rep(i,2,n) lgv[i] = lgv[i >> 1] + 1;
		rep(i,1,lgv[n]) for (int j = 1; j + (1 << i) - 1 <= n; ++ j) 
			st[j][i] = min(st[j][i - 1], st[j + (1 << i - 1)][i - 1]);
	}

	int query(int l, int r) {
		if (l == r) return n + 1;
		if (l > r) swap(l, r); ++ l;
		int d = lgv[r - l + 1];
		return min(st[l][d], st[r - (1 << d) + 1][d]);
	}
}; 

例题

posted @ 2022-10-09 21:04  joke3579  阅读(67)  评论(1编辑  收藏  举报