后缀数组讲解及例题

时间复杂度:倍增求法,复杂度 \(O(nlogn)\)

首先把 \(s\) 的每个后缀字典序排序。
\(sa[i]:\) 排名第 \(i\) 位的是第几个后缀(起始下标)。
\(rk[i]:\)\(i\) 个(起始下标为 \(i\))的后缀的的排名。
\(height[i]:\) \(sa[i]\)\(sa[i-1]\) 的最长公共前缀。

\(height\) 数组的求法:
假设所有后缀都已经排好序了,求 \(Lcp(i,j)\)
有:(\(i,j,k\) 均为排名,不作证明) $$Lcp(i,j) = Lcp(j,i)$$

\[Lcp(i,i) = len(i) \]

\[Lcp(i,j)=min(Lcp(i,k), Lcp(k,j)),i<=k<=j \]

\[height[i]=Lcp(i-1,i) \]

定义 \(h(i)=height[rk[i]]\),即起始下标为 \(i\) 的后缀与它排名前一的后缀的最长公共前缀
有:

\[h(i)>=h(i - 1) - 1 \]

代码详细注释:

const int N = 1e6 + 10;
int n, m;
char s[N];
int sa[N], x[N], y[N], c[N], rk[N], height[N];
//c[i]每个关键字的个数
//sa[i] 排名第i位后缀编号 rank[i]编号为i的后缀的排名
//height[i] sa[i]与sa[i-1]的最长公共前缀
inline void get_sa(){
    for(int i = 1; i <= n; i ++) c[x[i] = s[i]] ++; //初始第一关键字的排名就设置为其ASCII码即可
    for(int i = 2; i <= m; i ++) c[i] += c[i - 1]; //统计前缀和,统计小于等于每个数的关键字有多少,注意只加到m
    for(int i = n; i; i --) sa[c[x[i]] --] = i; //从后往前求每个后缀的排名,数量需要减一,以第一关键字排序
    for(int k = 1; k <= n; k <<= 1){ //倍增计算sa
        int num = 0; //计算Y,排名从1开始指针指向0
        for (int i = n - k + 1; i <= n; i ++) y[++ num] = i; //后k个没有第二关键字,排名最小先分配排名
        for (int i = 1; i <= n; i ++ ){ //从小到大枚举第二关键字
            if (sa[i] <= k) continue; //前k个第一关键字不能作为某个后缀的第二关键字
            y[ ++ num] = sa[i] - k; 
            //sa为以第一关键字计算的排名,从小到大枚举排名,对应的下标其实是第 sa[i]-k 个后缀的第二关键字
        }
        for (int i = 1; i <= m; i ++ ) c[i] = 0; //清空关键字数量
        for (int i = 1; i <= n; i ++ ) c[x[i]] ++; //计算新的第一关键字每个排名有几个
        for (int i = 2; i <= m; i ++ ) c[i] += c[i - 1]; //计算前缀和
        for (int i = n; i; i -- ) sa[c[x[y[i]]] -- ] = y[i], y[i] = 0;
        //y[i]以第二关键字排序,排名为i的后缀的下标 x[y[i]]上述后缀按照第一关键字的排名
        //c[x[y[i]]] 表示小于等于上述排名的数量,也就是该后缀的排名
        //sa记录上述排名对应的下标为 y[i],从后往前枚举第二关键字的排名
        //使得第一关键字相同的后缀也可以依靠第二关键字区分
        swap(x, y);         
        //接下来要更新X数组,且要用到旧的X数组,Y数组接下里用不到,
        //则将两者交换,目的是将旧的X存到Y中,后面所有的Y实际就是旧的X
        //将当前的第一关键字和第二关键字当做下一轮的第一关键字,sa中存的就是按照双关键字排序的结果。
        x[sa[1]] = 1, num = 1; //sa[1]对于的后缀新 X的排名也为1
        for (int i = 2; i <= n; i ++ )
            //如果新排名为i的后缀和新排名为i-1的后缀的第一关键字排名相同(前一个 == )
            //并且它们的第二关键字排名也相同(后一个 == ),那么两个后缀的新X排名相同,否则不同
            x[sa[i]] = (y[sa[i]] == y[sa[i - 1]] && y[sa[i] + k] == y[sa[i - 1] + k]) ? num : ++ num;
        if (num == n) break; //如果已经完全区分出n个后缀了,则可以结束循环
        m = num; //更新离散化后的rank范围
    }   
}
//height[i] 表示排名为i和排名为i-1的后缀的最长公共前缀
inline void get_height(){
    for(int i = 1; i <= n; i ++) rk[sa[i]] = i; //排名为i的字符串(sa[i])排名为i
    for(int i = 1, k = 0; i <= n; i ++){
        if(rk[i] == 1) continue; //排名为1的height不用计算
        //设h[i]表示height[rk[i]],即位于第 i个的后缀与排名在它前一个的后缀的最长公共前缀
        if(k) k --; //由于h[i]>=h[i-1]-1,所以从 k-1开始枚举
        int j = sa[rk[i] - 1]; //排名在i前一个的后缀的下标
        while(i + k <= n && j + k <= n && s[i + k] == s[j + k]) k ++; //如果相等,则最长公共前缀+1
        height[rk[i]] = k; //更新height
    }
}

例题:

\(1.\)品酒大会(\(NOI2015\))

题意:\(n\) 杯酒,每杯酒有一个标签为小写英文字母,(\(str(l, r)\)\(l\)\(r\) 长度为 \(r -l + 1\) 的字符串,貌似没啥用)。\(p\)\(q\)\(r\) 相似”为从 \(p\), \(q\) 开始的长度为 \(r\) 的前缀相同,每一杯酒有美味度,对于每一个 \(r\)\(1\)\(n-1\) 选出两杯 \(r\) 相似的酒,求 \(r\) 相似对数和 \(r\) 相似中两杯美味度相乘最大值。

思路:用后缀数组得到一个后缀排序的结果,并且 \(height\) 数组可以得到相邻两个后缀最长公共前缀的长度,并且有 *性质 :\(i, j\) 最长前缀长度等于之间的相邻两两最长前缀最小值。对于每个固定的 \(r\) 来说,利用所有 \(height[i] < r\)\(i\) 将数列分成若干段,这样所有 \(r\) 相似的后缀(酒)不可能出现在不同的段里(由于上面的 *性质),段内任意两杯酒一定 \(r\) 相似,由此可以解决第一问求个数(组合数学)。对于第二问统计最大值,三种情况可以化为两种情况(考虑正负)所以只需要在每个段里维护最大次大或最小次小就可以。现在对于固定的 \(r\) 都可以解决。现在需要解决的是所有 \(r\)。考虑用什么顺序遍历比较好处理,从大到小,最开始段多,后来段变少,合并段较容易维护所需信息。区间合并使用并查集 \(O(n)\)
总复杂度 \(O(n) + O(nlogn)\),并查集加后缀数组。

#include <map>
#include <cmath>
#include <queue>
#include <vector>
#include <cstdio>
#include <string>
#include <cstring>
#include <iostream>
#include <algorithm>
#define int long long
using namespace std;

template <class T> inline void read(T &x){
    x = 0; register char c = getchar(); register bool f = 0;
    while (!isdigit(c)) f ^= c == '-', c = getchar();
    while (isdigit(c)) x = x * 10 + c - '0', c = getchar();
    if (f) x = -x;
}

template <class T> inline void print(T x){
    if (x < 0) putchar('-'), x = -x;
    if (x > 9) print(x / 10);
    putchar('0' + x % 10);
}

#define x first
#define y second
#define PII pair<int, int>

const int N = 3e5 + 10;
const int inf = 2e18;
int n, m; char s[N];
int sa[N], x[N], y[N], c[N], rk[N], height[N];
int w[N], p[N], sz[N];//并查集数组,并查集元素个数
int max1[N], max2[N], min1[N], min2[N];//最大次大最小次小
vector<int> hs[N];//记录每种值的height有哪些
PII ans[N];

//后缀数组
inline void get_sa(){
    for(int i = 1; i <= n; i ++) c[x[i] = s[i]] ++;
    for(int i = 2; i <= m; i ++) c[i] += c[i - 1];
    for(int i = n; i; i --) sa[c[x[i]] -- ] = i;
    for(int k = 1; k <= n; k <<= 1){
        int num = 0;
        for(int i = n - k + 1; i <= n; i ++) y[++ num] = i;
        for(int i = 1; i <= n; i ++)
            if(sa[i] > k)
                y[ ++ num] = sa[i] - k;
        for(int i = 1; i <= m; i ++) c[i] = 0;
        for(int i = 1; i <= n; i ++) c[x[i]] ++ ;
        for(int i = 2; i <= m; i ++) c[i] += c[i - 1];
        for(int i = n; i; i --) sa[c[x[y[i]]] -- ] = y[i], y[i] = 0;
        swap(x, y);
        x[sa[1]] = 1, num = 1;
        for (int i = 2; i <= n; i ++)
            x[sa[i]] = (y[sa[i]] == y[sa[i - 1]] && y[sa[i] + k] == y[sa[i - 1] + k]) ? num : ++ num;
        if(num == n) break;
        m = num;
    }
}


inline void get_height(){
    for (int i = 1; i <= n; i ++) rk[sa[i]] = i;
    for (int i = 1, k = 0; i <= n; i ++){
        if(rk[i] == 1) continue;
        if(k) k -- ;
        int j = sa[rk[i] - 1];
        while(i + k <= n && j + k <= n && s[i + k] == s[j + k]) k ++ ;
        height[rk[i]] = k;
    }
}

inline int find(int x){
    if(p[x] != x) p[x] = find(p[x]);
    return p[x];
}

inline int get(int x){
    return x * (x - 1ll) / 2; //即组合数C²x,x选2
}

inline PII calc(int r){
    static int cnt = 0, maxv = -inf;
    for(auto x : hs[r]){//将所有值等于r的height合并
        int a = find(x - 1), b = find(x); //合并x和x-1
        cnt -= get(sz[a]) + get(sz[b]); //减去原来两个区间的cnt
        p[a] = b; //合并 
        sz[b] += sz[a];
        cnt += get(sz[b]);//维护大小
        if(max1[a] >= max1[b]){
            max2[b] = max(max1[b], max2[b]);
            max1[b] = max1[a];
        }else if(max1[a] > max2[b]) max2[b] = max1[a];
        if(min1[a] <= min1[b]){
            min2[b] = min(min1[b], min2[a]);
            min1[b] = min1[a];
        }else if(min1[a] < min2[b]) min2[b] = min1[a];
        maxv = max(maxv, max(max1[b] * max2[b], min1[b] * min2[b]));
    }
    if(maxv == -inf) return {cnt, 0};
    return {cnt, maxv};
}

signed main(){
    read(n), m = 122;
    scanf("%s", s + 1);
    for(int i = 1; i <= n; i ++) read(w[i]);
    get_sa();
    get_height();
    for(int i = 2; i <= n; i ++) hs[height[i]].push_back(i);//把不同的height分类
    for(int i = 1; i <= n; i ++){
        p[i] = i; sz[i] = 1;
        max1[i] = min1[i] = w[sa[i]];
        max2[i] = -inf, min2[i] = inf;
    }//初始化并查集
    for(int i = n - 1; i >= 0; i --) ans[i] = calc(i);
    for(int i = 0; i < n; i ++) printf("%lld %lld\n", ans[i].x, ans[i].y);
    return 0;
}

\(2.\)生成魔咒(\(SDOI2016\))

前言:用后缀数组求字符串不同子串的数量。(静态问题)确定 \(i\) 后枚举第 \(i\) 个后缀的所有前缀。所有后缀的前缀集合就是所有子串的集合。(前置知识:P2408 不同子串个数 又水一道紫题)

\[ans = \frac {n(n+1)} 2- \sum_{i=1}^{n}height[rk[i]] \]

只需要在后缀数组板子上加上:

inline int solve(){
    int ans = 1ll * n * (n + 1) / 2;
    for(int i = 1; i <= n; i ++) ans -= height[rk[i]];
    return ans;
}

思路:来到动态,若每次往后添加一个字符,\(height\) 数组将会无法维护,将会影响前面所有的后缀,考虑将整个序列颠倒过来,边删除边维护 \(height\) 数组,动态维护,为了删除用双链表维护,对于\(i,j,k\) 若删除 \(j\)\(height[k] = min(height[j],height[k])\) 可以 \(O(1)\) 实现。总的来说,将字符串翻转,将询问序列翻转,这样就是每次删除一个后缀,


#include<map>
#include<cmath>
#include<queue>
#include<vector>
#include<cstdio>
#include<string>
#include<cstring>
#include<iostream>
#include<algorithm>
#include<unordered_map>
#define int long long
using namespace std;

template <class T> inline void read(T &x){
    x = 0; register char c = getchar(); register bool f = 0;
    while (!isdigit(c)) f ^= c == '-', c = getchar();
    while (isdigit(c)) x = x * 10 + c - '0', c = getchar();
    if (f) x = -x;
}

template <class T> inline void print(T x){
    if (x < 0) putchar('-'), x = -x;
    if (x > 9) print(x / 10);
    putchar('0' + x % 10);
}

const int N = 1e5 + 10;
int n, m, s[N];
int sa[N], x[N], y[N], c[N], rk[N], height[N];
int u[N], d[N], ans[N];

inline int get(int x){
    static unordered_map<int, int> hash;
    if(hash.count(x) == 0) hash[x] = ++ m;
    return hash[x];
}

inline void get_sa(){
    for(int i = 1; i <= n; i ++) c[x[i] = s[i]] ++ ;
    for(int i = 2; i <= m; i ++) c[i] += c[i - 1];
    for(int i = n; i; i --) sa[c[x[i]] --] = i;
    for(int k = 1; k <= n; k <<= 1){
        int num = 0;
        for(int i = n - k + 1; i <= n; i ++) y[++ num] = i;
        for(int i = 1; i <= n; i ++)
            if(sa[i] > k)
                y[++ num] = sa[i] - k;
        for(int i = 1; i <= m ; i ++) c[i] = 0;
        for(int i = 1; i <= n; i ++) c[x[i]] ++ ;
        for(int i = 2; i <= m; i ++) c[i] += c[i - 1];
        for(int i = n; i; i --) sa[c[x[y[i]]] --] = y[i], y[i] = 0;
        swap(x, y);
        x[sa[1]] = 1, num = 1;
        for (int i = 2; i <= n; i ++)
            x[sa[i]] = (y[sa[i]] == y[sa[i - 1]] && y[sa[i] + k] == y[sa[i - 1] + k]) ? num : ++ num;
        if(num == n) break;
        m = num;
    }
}

inline void get_height(){
    for(int i = 1; i <= n; i ++) rk[sa[i]] = i;
    for(int i = 1, k = 0; i <= n; i ++){
        if(rk[i] == 1) continue;
        if(k) k -- ;
        int j = sa[rk[i] - 1];
        while(i + k <= n && j + k <= n && s[i + k] == s[j + k]) k ++;
        height[rk[i]] = k;
    } 
}

signed main(){
    read(n);
    for(int i = n; i; i --) read(s[i]), s[i] = get(s[i]);
    get_sa();
    get_height();
    int res = 0;
    for(int i = 1; i <= n; i ++){
        res += n - sa[i] + 1 - height[i];
        u[i] = i - 1, d[i] = i + 1;
    }
    d[0] = 1, u[n + 1] = n;
    for(int i = 1; i <= n; i ++){
        ans[i] = res;
        int k = rk[i], j = d[k];
        res -= n - sa[k] + 1 - height[k];
        res -= n - sa[j] + 1 - height[j];
        height[j] = min(height[j], height[k]);
        res += n - sa[j] + 1 - height[j];
        d[u[k]] = d[k], u[d[k]] = u[k];
    }
    for(int i = n; i; i --) print(ans[i]), puts("");
    return 0;
}

\(3.\)字符加密(\(JSOI2007\))

#include <map>
#include <cmath>
#include <queue>
#include <vector>
#include <cstdio>
#include <string>
#include <cstring>
#include <iostream>
#include <algorithm>
using namespace std;

template <class T> inline void read(T &x){
    x = 0; register char c = getchar(); register bool f = 0;
    while (!isdigit(c)) f ^= c == '-', c = getchar();
    while (isdigit(c)) x = x * 10 + c - '0', c = getchar();
    if (f) x = -x;
}

template <class T> inline void print(T x){
    if (x < 0) putchar('-'), x = -x;
    if (x > 9) print(x / 10);
    putchar('0' + x % 10);
}

const int N = 1e6 + 10;
int n, m, t;
char s[N];
int sa[N], x[N], y[N], c[N], rk[N], height[N];

inline void get_sa(){
    for(int i = 1; i <= n; i ++) c[x[i] = s[i]] ++;
    for(int i = 2; i <= m; i ++) c[i] += c[i - 1];
    for(int i = n; i; i --) sa[c[x[i]] --] = i;
    for(int k = 1; k <= n; k <<= 1){
        int num = 0;
        for(int i = n - k + 1; i <= n; i ++) y[++ num] = i;
        for (int i = 1; i <= n; i ++ )
            if (sa[i] > k)
                y[ ++ num] = sa[i] - k;

        for (int i = 1; i <= m; i ++ ) c[i] = 0;
        for (int i = 1; i <= n; i ++ ) c[x[i]] ++ ;
        for (int i = 2; i <= m; i ++ ) c[i] += c[i - 1];
        for (int i = n; i; i -- ) sa[c[x[y[i]]] -- ] = y[i], y[i] = 0;
        swap(x, y);
        x[sa[1]] = 1, num = 1;
        for (int i = 2; i <= n; i ++ )
            x[sa[i]] = (y[sa[i]] == y[sa[i - 1]] && y[sa[i] + k] == y[sa[i - 1] + k]) ? num : ++ num;
        if (num == n) break;
        m = num;
    }   for(int i = 1; i <= n; i ++) c[x[i] = s[i]] ++;
}

inline void get_height(){
    for(int i = 1; i <= n; i ++) rk[sa[i]] = i;
    for(int i = 1, k = 0; i <= n; i ++){
        if(rk[i] == 1) continue;//height[1] = 0
        if(k) k --;
        int j = sa[rk[i] - 1];
        while(i + k <= n && j + k <= n && s[i + k] == s[j + k]) k ++;
        height[rk[i]] = k;
    }
}

signed main(){
    scanf("%s", s + 1);
    t = strlen(s + 1), m = 300, n = t << 1;
    for(int i = t + 1; i <= n; i ++) s[i] = s[i - t];
    get_sa();
    get_height();
    for(int i = 1; i <= n; i ++) if(sa[i] <=  t) printf("%c", s[(sa[i] + t - 1)]);
    return 0;
}

posted @ 2022-06-07 08:19  Altwilio  阅读(88)  评论(0编辑  收藏  举报