倍增法求后缀数组

0. 前言

请保证在学习本文之前学习过倍增算法。

本文适合有一定字符串基础的OIer学习。如果你没有字符串基础,那么你就会看不懂其中的一些内容。可以在下方评论区提出或私信联系作者。

1. 后缀数组

后缀数组是处理字符串的有力工具 —罗穗骞

说白点,后缀数组包含了两个东西:

  1. 后缀数组sa,也就是suffix array的缩写
    \(sa_i\)就表示排名为\(i\)的后缀的起始位置。
  2. 排名数组rnk,也就是rank的缩写
    \(rnk_i\)就表示起始位置为\(i\)的后缀的排名

根据定义,不难发现,

sa[rnk[i]]=i,起始位置为\(i\)的后缀排名为\(rnk_i\),那么排名为\(rnk_i\)的后缀的起始位置就是\(i\)

rnk[sa[i]]=i,排名为\(i\)的后缀的起始位置为\(sa_i\),那么起始位置为\(sa_i\)的后缀的排名就是\(i\)

所以他们两个可以相互转换

解释一下其中的概念:

1. 后缀

对于一个字符串str[0:n)来说(为了方便表示,这里字符串采用区间表示子串,例如str[0:n)就表示从0开始到n的一个左闭右开的子串,起始就是str串本身)

他的后缀就是从某一位开始一直数到最后的子串str[i:n)\(i\in[0,n-1)\)可以发现最小的后缀只有一个字符str[n-1],最大的后缀就是原串str[0:n)本身。

那么就是说,我们只需要知道后缀的起始位置就可以知道一整个后缀。

suf[i]表示原串str[0:n)的后缀str[i:n)。可以发现suf[i]是一个字符串类型,suf[i][0]才是字符类型。

举个例子,字符串str[0:7)="abcdefg",他的后缀有

suf[0]="abcdefg"

suf[1]="bcdefg"

suf[2]="cdefg"

suf[3]="defg"

suf[4]="efg"

suf[5]="fg"

suf[6]="g"

后缀数组就是在这些后缀的基础上,求出他们的大小关系。

2.大小

那么先定义字符串大小关系:

对于字符串\(a,b\)来说,大小关系如下(通俗版):

  1. 一位位从前往后比较,如果某一位不相同,那么这一位上ASCII小的那个字符串小
  2. 如果某一位上一个字符串的字符为空(不是空格,就是空,意味着这个字符串到结尾了),而另一个字符串不为空,那么定义这个为空的小。
  3. 如果两个字符串同时为空,那么定义这两个字符串相等

继续举例子,观察:

"abc"=="abc"(废话)

"ab"<"abc"

"abc"<"b"

观察第二个式子,其实我们可以看做我们在"ab"后面插入一个字符,一个极小的字符,他的ASCII是0,那么"ab\0"<"abc"就可以知道了。

同样,因为我们的后缀长度都不一样,所以我们可以考虑在短的后缀后面插入一些"\0"是的他们长度相同。

得到了一个大小关系的定义之后,我们很容易发现,比较长度为\(n_1,n_2\)的两个字符串最劣的情况需要\(\min(n_1,n_2)\)的时间复杂度。然后我们如果直接暴力排序的话是需要\(n^2\log_2n\)的时间复杂度才能把后缀数组求出来。

一般算法到了\(n^2\)级别就可以舍弃了。。。dij不是都到了\(n\log_2n\)了吗

因此我们要想办法优化。

这个办法怎么想出来的,我也不知道

但是我可以解释给你听这个算法怎么工作,以及他的正确性。

2. 后缀数组

我也不知道为什么两个标题都叫做后缀数组

对于长度相同的字符串的大小关系,我们可以这样看:

假设字符串\(a,b\)长度均为\(n\),且存在正整数\(m<n\),那么我们先比较a[0:m),b[0:m)的大小,如果比较不出来再比较a[m,n),b[m:n)的大小。

看上去像是一个很废物的性质。

但是他为我们的倍增算法提供了新思路。

如果\(a,b\)都是str的后缀,设a=suf[c],b=suf[d]

如果我们手里有str[c:c+m),str[d:d+m)以及str[c+m:n),str[d+m:n)的大小关系,我们就可以得出\(a,b\)的大小关系。这里暂且不讨论越界的情况。

有一个很明显的,当\(m=1\)的时候,str[c:c+m),str[d:d+m)的大小关系其实就是他们ASCII码的大小关系。

所以我们就要找str[c+1:n),str[d+1:n)的大小关系就可以了。

那么这个str[c+1:n),str[d+1:n)的大小关系似乎也可以通过类似上面的方法拆成str[c+1:c+2),str[d+1:d+2)以及str[c+2:n),str[d+2:n)

边界?

由于我们在短的后缀后面插入了一堆"\0"使得他们的长度均为\(n\),因此我们可以得知\(m\)的最大值可以到\(n\),但是在\(n\)之前就可以比较得出答案了。因为后缀原长不同,因此末尾插入的连续"\0"的个数也一定不相同。那么只要比较到某一个是"\0"的时候就可以得出大小关系了。

不过这样每次一个个跳真的好慢。。。

联想一下我们倍增的方法,不妨设\(m\)为最大的,使得\(m< n\)成立的2的正整数次幂(其实就是\(m=2^{\lfloor\log_2(n-1)\rfloor}\)

那么我们先比较出str[c,c+m),str[d,d+m)的大小关系,然后再比较剩余部分就可以了(越界在后面补"\0")。

那么,我们不妨玩大一点。

我们在所有后缀后面都补一定量的"\0",使得他们的长度均为\(n'=2^{\lceil\log_2n\rceil}\)也就是变成了一个2的正整数次幂。

那么我们的\(m'=2^{\lfloor\log_2(n'-1)\rfloor}\)就满足\(2m'=n'\)

也就是说我们字符串分成的str[c,c+m'),str[c+m',c+n')两部分的长度应该是一样的,均为\(m'\),在算上了我们补的"\0"之后。

然后\(m'\)也是一个2的正整数次幂。

所以我们又可以引用上面的方法,只要我们能比较出str[c,c+m'/2),str[c+m'/2,c+m)就可以得到前半部分的线索。

以此类推。

边界?

\(m'\)不是\(2\)的正整数次幂的时候,那么就应该是\(2^0=1\)

\(m\)为1的时候,我们讨论过了,直接就是ASCII

因此我们可以从\(m=1\)的情况一路倍增回去

倍增到最后,我们就可以求出在补了"\0"之后的后缀们的相对大小关系。

退一步,在最后一次倍增执行之前,我们的sa数组存的就是长度为\(n'/2\)的子串的大小关系。

退回来,在第一步,我们的rnk数组就应该存下长度为\(1\)的子串的大小关系,就是他们的ASCII值。

于是一开始的基本框架我们可以搭建起来了。

#include <cstdio>
#include <cstring>
char str[1000001];
int n,sa[1000001],rnk[1000001];
int main()
{
	scanf("%s",str);
	n=strlen(str);
	for(register int i=0;i<n;++i) rnk[i]=str[i];
	// do sth.
	return 0;
}

在开始我们的倍增之前,还要好好说说这个排序。

我们前面一直在鼓吹我们对于str[c:c+n'),str[d:d+n')的大小关系比较可以拆成str[c:c+n'/2),str[d:d+n'/2)以及str[c+n'/2:c+n'),str[d+n'/2:d+n')。这是两个字符串的比较。

我们现在是在求长度为\(n'\)的子串的排名信息。

那么我们倍增上来就应该已经求出了\(n'/2\)的排名信息,储存在当前的sa,rnk中。

我们可以通过rnk[c]得之str[c:c+n'/2]在所有长度为\(n'/2\)的串中的排名。同样rnk[d]就表示str[d:d+n'/2)的排名。如果rnk[c]!=rnk[d]那么就可以直接比较出来两个串 的大小关系。

如果rnk[c]==rnk[d]那么我们就去比较rnk[c+m'],rnk[d+m']如果他们不相等也可以得到两个串的大小关系。如果他们也相等。。。就姑且认为长度为\(n'\)的两个串str[c:c+n'),str[d:d+n')相等。

那么我们字符串的比较就转化为了(rnk[c],rnk[c+m']),(rnk[d],rnk[d+m'])的二元组的比较。

我们也可以照猫画虎定义出二元组比较的规则:

二元组\((a,b),(c,d)\)比较大小

  1. 第一元素小的小
  2. 第一元素相同的情况下,第二元素小的小
  3. 两个元素都相同,则两个二元组相等。

那么我们一开始的时候,长度为\(1\),并不能拆成两个字符串。这种初始情况我们特殊处理一下,就让第\(i\)个二元组为\((str_i,i)\)这就意味着ASCII的大小作为第一元素比较。如果第一元素相同则位置作为第二元素比较。位置靠后的这个二元素排名也靠后。

tuple_sort(0xff); // 为什么是0xff?后面告诉你

然后我们就可以开始倍增了。

先设我们已经求出了长度为\(m\)\(n\)个子串之间的排名信息(越界在后面补"\0"),那么我们现在就要求长度为\(2m\)的排名信息

for(int m=1;m<n;m<<=1)
{
    // do sth.
}

我们每次手中拥有的就是rnk[i]表示子串str[i:i+m)在所有长度为\(m\)的子串中的排名,对应的我们也可以求出sa。而我们的目标就是要更新rnk[i]使他表示子串str[i:i+2m)在所有长度为\(2m\)的字符串中的排名。

很明显我们需要把str[i:i+2m)拆成str[i:i+m),str[i+m:i+2m),用rnk[i]作为第一元素,rnk[i+m]作为第二元素求出他们的排名。

使用快排,这样时间复杂度\(n\log_2n\),再算上一个倍增,总时间复杂度\(n\log_2^2n\)

感觉还可以。

但是我们需要优化一下。采用基数排序。

说那么高大上其实就是桶排。

由于这个桶排太过神奇,导致我也不知道他是怎么兼容二元组排序的。所以我只能告诉你他是怎么做的:

首先定义一个数组\(tp_i\)表示第二元素排名为\(i\)的长度为\(2m\)的串的起始位置。

有点绕,换种方法解释一次,起始位置为\(tp_i\)的长度为\(2m\)的串的第二元素在所有第二元素中的排名中为\(i\)

那么很明显,有些串的第二元素是一堆的"\0",这其实就意味着这个位置开始的后缀长度根本不足\(m\),使得他们没有第二元素。根据字符串大小关系的第二条定义,短的小,因此这些没有第二元素的串应该是最小的。而他们内部,由于都没有第二元素,所以就应该按照起始位置排序。

int cnt=0;
for(int i=n-m;i<n;++i) tp[cnt++]=i;

解决了这些没有第二元素的串之后,我们就考虑有第二元素的串了。

直接求好像很难算。但是我们比较一下

sa[i]:当前长度\(m\)下排名为\(i\)的串的起始位置

tp[i]:长度\(2m\)下第二元素排名为\(i\)的串的起始位置

稍加思索,可以发现tp[i]=sa[i]-m

什么意思?如果第二元素排名为\(i\)的串的起始位置是tp[i],那么他的第二元素的起始位置就是tp[i]+m,长度为\(m\),我们的sa[i]就表示长度为\(m\)的排名为\(i\)的串的起始位置。

所以是有等式\(tp_i+m=sa_i\)成立,可以解出tp[i]=sa[i]-m

但是并不是所有sa都是有用的。比如当前长度\(m=16\),然后sa[5]=1。虽然他的排名是5,但是他起始位置是1,是作为起始位置为-15的串的第二元素。而这个串不存在,所以不能更新

for(int i=0;i<n;++i) if(sa[i]>=m) tp[cnt++]=sa[i]-m 

为什么两次循环都是cnt++呢?可以发现,前面第一次循环是越短越小,因此我按照位置递增枚举。而第二次循环,他们都有第二元素,sa[i]里面\(i\)递增,因此第二元素的排名也递增,所以也是cnt++就好了。

然后我们第一元素和第二元素都齐全了,我们进行排序

tuple_sort(cnt);

现在我们就不得不来讲讲这个排序了

3. 基数排序

首先定义数组tax[i]表示桶

void tuple_sort(const int N)
{
    /// N: 表示桶的大小。是用来优化的
    // 因为一开始的时候排名最多就是0xff个(char也就最能存0xff个)
    for(int i=0;i<N;++i) tax[i]=0; // 清空
    for(int i=0;i<n;++i) tax[rnk[i]]++; // 统计第一元素排名为rnk[i]的串的个数
    for(int i=1;i<N;++i) tax[i]+=tax[i-1]; // 做前缀和
    // 这里做了前缀和之后tax[x]就表示排名小于等于x的串的个数有多少个。
    for(register int i=n-1;i>=0;--i) sa[--tax[rnk[tp[i]]]]=tp[i]; // @#$%^&*
}

详细解释一下最后一行。

我们可以发现,第二元素排名为i的串的第一元素排名为rnk[tp[i]]

很绕,换种方法解释一下。

对于一个串,起始位置为tp[i],那么他就满足第二元素排名为\(i\),那么这个起始位置为tp[i]的串的第一元素排名就是rnk[tp[i]](因为当前rnk储存的还是旧信息,尚未翻新)

把上面那句话顺序整理一下,就可以得到

第二元素排名为i的串的第一元素排名为rnk[tp[i]]

所以我们逆序枚举\(i\),可以保证第二元素的排名单调递减,也就是先被遍历到的第二元素一定更大。

那么他的第一元素排名为rnk[tp[i]],再根据前缀和的结果我们可以得出第一元素排名不超过rnk[tp[i]]的有tax[rnk[tp[i]]]个,这其中有一些第一元素相同,第一元素不同。

我们的i是倒序枚举的,也就是说在第一元素相同的情况下,我们当前的tp[i]作为起始位置的串第二元素会更大,啊不对应该是最大。因此他在所有长度为\(2m\)的串中的排名就应该是tax[rnk[tp[i]]]-1(我的代码喜欢从0开始用数组,如果你喜欢从1开始也可以,就。。。细节要改亿点点)

于是乎,这个起始位置为tp[i]的串的排名就是tax[rnk[tp[i]]]-1。又因为这是最大的一个,我们不能妨碍后面第一元素相同的串来找他的排名,因此排名小于等于rnk[tp[i]]的串的个数应该减一,所以就是--tax[rnk[tp[i]]]

好了我们现在得出了长度为\(2m\)sa数组。

照理说我们只需要rnk[sa[i]]=i就可以得到rnk数组了。

但是要注意,我们说上面那条式子成立,他的条件是rnk,sa表示的都是后缀的信息。显然我们现在表示的不是后缀的信息。

后缀信息有什么特别的呢?后缀的信息可以保证rnk,sa两两不同。但是我们当前没法保证,当前的串有可能是相同的。

那么。。。相同怎么办?

相同的串就姑且让他排名也相同吧。。。

我们可以发现到此为止我们有两个变量已经完成了他们的使命了,一个是数组tp,另一个是计数器cnt

在本轮循环结束之前,姑且让他们再贡献一下。

我们可以知道如果没有重复那么我们的rnk数组就应该这样获得

for(int i=0;i<n;++i) rnk[sa[i]]=i;

他跟下面的代码是等价的

cnt=0;
for(int i=0;i<n;++i) rnk[sa[i]]=cnt++;

现在有些串是相等的,那么对于相等的串我们就不能赋值cnt,而应该赋值cnt-1,与此同时cnt也不应该自增。

所以rnk[sa[i]]=(equal?) cnt-1:cnt++;

至于这个相等怎么判断,我们根据上面的知识,两个串比较大小可以变成两个二元组比较大小,这个二元组就是旧的rnk组成的二元组。而我们现在rnk正在更新,所以我们应该把旧的rnk保存下来。比较暴力的方法就是memcpy

而文雅一点的方法就是先把rnk和现在没用的tp都定义成指针类型,然后直接swap交换两个指针就可以了。

那么相等的判断就很简单了。如果排名为i的串和排名为i-1的串相等,当且仅当(tp[sa[i-1]]==tp[sa[i]] and tp[sa[i-1]+m]==tp[sa[i]+m])需要注意这里tp并不是原来的意思了,他已经变成了旧的rnk(至少我自己看回来都瞪了好久/捂脸)

上面的式子意思就是当二元组两个元素都相等的时候,就判定这两个串相等。

我们看见这里面调用了sa[i-1]这就意味着我们\(i\)最小值只能是1,所以i=0的情况要我们特殊处理。可以发现i=0并不会和前面的串重复,因此rnk[sa[0]]=(cnt=0)++;


一切解释完毕。

最后一步:把这些代码背下来。

// https://www.luogu.com.cn/problem/P3809
#include <cstdio>
#include <cstdlib>
#include <cstring>
#include <algorithm>
char str[1000001];
int n,sa[1000001],rnk[1000001];
int tp[1000001],tax[1000001],*tmp;
void tuple_sort(const int N)
{
	for(register int i=0;i<N;++i) tax[i]=0;
	for(register int i=0;i<n;++i) tax[rnk[i]]++;
	for(register int i=1;i<N;++i) tax[i]+=tax[i-1];
	for(register int i=n-1;i>=0;--i) sa[--tax[rnk[tp[i]]]]=tp[i];
}
int main()
{
	scanf("%s",str);
	n=strlen(str);
	for(register int i=0;i<n;++i) rnk[i]=str[i], tp[i]=i;
	// 因为下面排序使用tp[i]作为第二关键字,因此将需要的第二关键字传入tp[i]
	tuple_sort(0xff);
	for(register int m=1;m<n;m<<=1)
	{
		register int cnt=0;
		for(register int i=n-m;i<n;++i) tp[cnt++]=i;
		for(register int i=0;i<n;++i) if(sa[i]>=m) tp[cnt++]=sa[i]-m;
		tuple_sort(cnt);
		std::swap(rnk,tp);
		rnk[sa[0]]=(cnt=0)++;
		for(int i=1;i<n;++i) rnk[sa[i]]=(tp[sa[i-1]]==tp[sa[i]] and tp[sa[i-1]+m]==tp[sa[i]+m])? cnt-1:cnt++;
	}
	for(register int i=0;i<n;++i) printf("%d ",sa[i]+1);
	printf("\n");
	return 0;
}
posted @ 2022-04-16 16:11  IdanSuce  阅读(95)  评论(4编辑  收藏  举报