后缀数组
基本概念
给你一个字符串 \(s\),对于所有 \(s\) 的后缀,我们按照字典序排序。最后输出排在 \(i\) 的后缀是原串里的第几个。
\(sa_i\) 指字典序排名第 \(i\) 的是第几个前缀。
\(O(n^2\log n)\)
直接找出所有后缀,排序。
\(O(n\log^2 n)\)
我们需要用到倍增来实现。例如这样:
这个的思想大概是这样的:先用 \(s\) 所有长度为 \(1\) 的子串进行排序,得到 \(sa\) 数组和 \(x\) 数组。
然后开始倍增。每次使用长度为 \(\dfrac{w}{2}\) 的子串进行排序。具体地,用 \(x_{i}\) 和 \(x_{i+\frac{w}{2}}\) 作为排序的一二关键字进行排序。
于是就做完了,给个代码:
void get_sa(){
for(int i=1;i<=n;i++){
sa[i]=i;
x[i]=s[i];
}
for(int k=1;k<n;k<<=1){
sort(sa+1,sa+n+1,[&](int a,int b){
return x[a]==x[b]?x[a+k]<x[b+k]:x[a]<x[b];
});
memcpy(y,x,sizeof x);
int num=0;
for(int i=1;i<=n;i++){
x[sa[i]]=(y[sa[i]]==y[sa[i-1]]&&y[sa[i]+k]==y[sa[i-1]+k])?num:++num;
}
}
}
这个代码已经可以通过部分平台的测试数据,但是为了保险,我们需要继续优化。
\(O(n\log n)\)
我们思考为什么会有两个 \(\log\),因为我们的 \(sort\) 太慢了。所以我们要考虑基数排序(这里插一句,如果写的常数太大会在一些平台上超时)。
我们考虑先做第一步,按照单个字符排序,这个是比较简单的,大概就是开个桶记录一下每个字符的数量,然后做个前缀和,最后倒着扫一遍得到 \(sa\) 数组的值。大概是这样:
代码:
for(int i=1;i<=n;i++)c[x[i]=s[i]]++;
for(int i=1;i<=m;i++)c[i]+=c[i-1];
for(int i=n;i;i--)sa[c[x[i]]--]=i;
为什么这样是对的。我们考虑做过前缀和的 \(c\) 数组代表什么。代表的是这个字母最靠后的可能位置,而最后一行的意思是把这个字母当前最靠后的可能位置的 \(sa\) 弄成这个位置。
然后我们开始倍增,考虑如何对第二关键字进行排序。因为我们会把第二关键字越界的位置的第二关键字设为 \(0\),事实上就是 \(i+k>n\) 的位置第二关键字为 \(0\)。
所以我们在对第二关键字排序时,把这些为 \(0\) 的放到最前面,事实上就是 \(n-k+1\sim n\) 的位置。
然后剩下的位置就按照原来的 \(sa\) 放到后面。比方说第 \(3\) 个位置的 \(sa\) 为 \(4\),然后当前的 \(k\) 为 \(1\),那么用到这个 \(4\) 的就是位置 \(3-k=2\)。然后对于放到后面(第二关键字不为 \(0\))的第 \(2\) 个数就是 \(4-k=3\)。
代码:
int num=0;
for(int i=n-k+1;i<=n;i++)y[++num]=i;
for(int i=1;i<=n;i++){
if(sa[i]>k){
y[++num]=sa[i]-k;
}
}
这里我们 \(y\) 数组的含义为:以第二关键字排序,排名为 \(i\) 的是第几个后缀。
然后我们重复一下一开始统计长度为 \(1\) 的子串时的操作。就是开个桶统计一下,然后做个前缀和。
代码:
for(int i=1;i<=m;i++)c[i]=0;
for(int i=1;i<=n;i++)c[x[i]]++;
for(int i=1;i<=m;i++)c[i]+=c[i-1];
for(int i=n;i;i--)sa[c[x[y[i]]]--]=y[i],y[i]=0;
这里解释一下数组含义:
\(y_i\) 表示以第二关键字排序,排名为 \(i\) 的是第几个后缀。
\(x_{y_i}\) 表示上述后缀按照第一关键字的排名。
\(c_{x_{y_i}}\) 表示小于等于上述排名的数量,也就是这个后缀的排名。
\(sa_{c_{x_{y_i}}}\) 的值为 \(y_i\) 表示上述排名对应的是第 \(y_i\) 个后缀。
这里倒着循环的是为了在第一关键字相同的情况下保持第二关键字的原始顺序。
然后我们的 \(y\) 就没有用了,但是 \(x\) 还是有用的。为了方便,我们交换 \(x,y\)。
接着把 \(sa_{x_1}\) 按照第一关键字的排名设成 \(1\),开始对 \(x\) 赋值。就是如果这个位置的排名和上一个位置的排名完全相同,\(x_{sa_i}\) 就是 \(x_{sa_{i-1}}\);否则是 \(x_{sa_{i-1}}+1\)。
最后就是如果已经区分出来排名了,就没有必要继续循环了。
代码:
swap(x,y);
x[sa[1]]=num=1;
for(int i=2;i<=n;i++){
x[sa[i]]=(y[sa[i]]==y[sa[i-1]]&&y[sa[i]+k]==y[sa[i-1]+k])?num:++num;
}
if(num==n)break;
m=num;
完整代码:
void get_sa(){
for(int i=1;i<=n;i++)c[x[i]=s[i]]++;
for(int i=1;i<=m;i++)c[i]+=c[i-1];
for(int i=n;i;i--)sa[c[x[i]]--]=i;
//其实这里的x_i可以看成x_{y_i},y_i这里就是i
for(int k=1;k<=n;k<<=1){
int num=0;
for(int i=n-k+1;i<=n;i++)y[++num]=i;
for(int i=1;i<=n;i++){
if(sa[i]>k){
y[++num]=sa[i]-k;
}
}
for(int i=1;i<=m;i++)c[i]=0;
for(int i=1;i<=n;i++)c[x[i]]++;
for(int i=1;i<=m;i++)c[i]+=c[i-1];
for(int i=n;i;i--)sa[c[x[y[i]]]--]=y[i],y[i]=0;
swap(x,y);
x[sa[1]]=num=1;
for(int i=2;i<=n;i++){
x[sa[i]]=(y[sa[i]]==y[sa[i-1]]&&y[sa[i]+k]==y[sa[i-1]+k])?num:++num;
}
if(num==n)break;
m=num;
}
}
但是,很多题只有 \(sa\) 是不够的,于是我们定义 \(hei_i\) 为排名 \(i\) 的后缀与排名 \(i-1\) 的后缀的最长公共前缀的长度。
为了求出这个长度,我们还需要一个 \(rk\) 数组,代表第 \(i\) 个后缀的排名。
注意这里有个结论,\(hei_i\ge hei_{i-1}+1\)。
然后就是循环求 \(hei\) 数组。代码:
void get_hei(){
int k=0;
for(int i=1;i<=n;i++)rk[sa[i]]=i;
for(int i=1;i<=n;i++){
if(rk[i]==1)continue;
if(k)k--;
int j=sa[rk[i]-1];
while(i+k<=n&&j+k<=n&&s[i+k]==s[j+k])k++;
hei[rk[i]]=k;
}
}
这个 \(j\) 代表前一个排名的后缀是第 \(j\) 个。
题目
后缀排序
板子题,直接看上面的东西即可。
字符加密
我们首先把字符串复制一份接到原来的字符串后面。
然后我们直接对其求一个 \(sa\) 数组。输出一下就做完了。
这里说一下为什么复制一遍是对的。因为考虑如果说在前 \(n\) 位没有排出名次(排出名次就没有后面了),那么有影响的其实是这个串的前面,也就是说影响排名的还是自己,所以这样是对的。
代码:
#include<bits/stdc++.h>
#define int long long
#define N 200005
using namespace std;
char s[N];
int n,m,x[N],y[N],c[N],sa[N];
void get_sa(){
for(int i=1;i<=n;i++)c[x[i]=s[i]]++;
for(int i=1;i<=m;i++)c[i]+=c[i-1];
for(int i=n;i;i--)sa[c[x[i]]--]=i;
for(int k=1;k<=n;k<<=1){
int num=0;
for(int i=n-k+1;i<=n;i++)y[++num]=i;
for(int i=1;i<=n;i++){
if(sa[i]>k){
y[++num]=sa[i]-k;
}
}
for(int i=1;i<=m;i++)c[i]=0;
for(int i=1;i<=n;i++)c[x[i]]++;
for(int i=1;i<=m;i++)c[i]+=c[i-1];
for(int i=n;i;i--)sa[c[x[y[i]]]--]=y[i],y[i]=0;
swap(x,y);
x[sa[1]]=num=1;
for(int i=2;i<=n;i++){
x[sa[i]]=(y[sa[i-1]]==y[sa[i]]&&y[sa[i-1]+k]==y[sa[i]+k])?num:++num;
}
if(num==n)break;
m=num;
}
}
signed main(){
cin>>s+1;
n=strlen(s+1);
for(int i=n+1;i<=n<<1;i++){
s[i]=s[i-n];
}
n<<=1;
m='z';
get_sa();
for(int i=1;i<=n;i++){
if(sa[i]<=n>>1)cout<<s[sa[i]+(n>>1)-1];
}
return 0;
}