「模板」后缀数组

经过本蒟蒻大约两天的努力,总算是把后缀数组的模板写出来了 然而 ZXY 大佬已经 AK 了

我们不能和这种神仙比,还是自己比比就好了...

我的后缀数组实现有两个版本,它们的本质区别只在于基数排序部分的实现方式。

后面会加入求 height[] 的代码,不过现在还是算了吧...

用伪链表实现基数排序可能好理解一些,但是容易因为内存访问过于跳跃而被卡常。

而用经典基数排序写法实现的代码,可能比较难理解,这个自己想想就好办了...

伪链表实现基数排序

#include<cstdio>
#include<cstring>

#define rep(i,__l,__r) for(signed i=(__l),i##_end_=(__r);i<=i##_end_;++i)
#define fep(i,__l,__r) for(signed i=(__l),i##_end_=(__r);i>=i##_end_;--i)
#define writc(a,b) fwrit(a),putchar(b)
#define mp(a,b) make_pair(a,b)
#define ft first
#define sd second
#define LL long long
#define ull unsigned long long
#define uint unsigned int
#define pii pair< int,int >
#define Endl putchar('\n')
// #define FILEOI
// #define int long long
// #define int unsigned
// #define int unsigned long long

#ifdef FILEOI
# define MAXBUFFERSIZE 500000
    inline char fgetc(){
        static char buf[MAXBUFFERSIZE+5],*p1=buf,*p2=buf;
        return p1==p2&&(p2=(p1=buf)+fread(buf,1,MAXBUFFERSIZE,stdin),p1==p2)?EOF:*p1++;
    }
# undef MAXBUFFERSIZE
# define cg (c=fgetc())
#else
# define cg (c=getchar())
#endif
// Reads the next decimal integer from the input (via the `cg` macro) into x.
// Every character before the first digit is skipped; if any of the skipped
// characters is '-', the value is negated.
// If the input ends before a digit is found, x is set to 0 and the function
// returns (the skip loop used to spin forever on EOF). A 0xFF byte is
// indistinguishable from EOF here and also ends the read.
template<class T>inline void qread(T& x){
    char c;bool f=0;
    x=0;
    while(cg<'0'||'9'<c){
        if(c==static_cast<char>(EOF))return;
        f|=(c=='-');
    }
    // x*10 is written as (x<<1)+(x<<3); c^48 maps '0'..'9' to 0..9.
    for(x=(c^48);'0'<=cg&&c<='9';x=(x<<1)+(x<<3)+(c^48));
    if(f)x=-x;
}
// Reads the next decimal integer from the input (via the `cg` macro) and
// returns it. Characters before the first digit are skipped; a '-' among
// them negates the result.
// Returns 0 if the input ends before any digit is found (the skip loop used
// to spin forever on EOF). A 0xFF byte is indistinguishable from EOF here.
inline int qread(){
    int x=0;char c;bool f=0;
    while(cg<'0'||'9'<c){
        if(c==static_cast<char>(EOF))return 0;
        f|=(c=='-');
    }
    // x*10 is written as (x<<1)+(x<<3); c^48 maps '0'..'9' to 0..9.
    for(x=(c^48);'0'<=cg&&c<='9';x=(x<<1)+(x<<3)+(c^48));
    return f?-x:x;
}
// template<class T,class... Args>inline void qread(T& x,Args&... args){qread(x),qread(args...);}
// Returns the larger argument; y is returned when the two compare equal.
template<class T>inline T Max(const T x,const T y){
    if(x>y)return x;
    return y;
}
// Returns the smaller argument; y is returned when the two compare equal.
template<class T>inline T Min(const T x,const T y){
    if(x<y)return x;
    return y;
}
// Absolute value: x itself when it is positive, otherwise its negation.
template<class T>inline T fab(const T x){
    if(x>0)return x;
    return -x;
}
// Greatest common divisor by the Euclidean algorithm.
// gcd(a,0) is a; the remainder chain is (a,b) -> (b,a%b) until b is 0.
inline int gcd(const int a,const int b){
    int hi=a,lo=b;
    while(lo){
        const int rem=hi%lo;
        hi=lo;
        lo=rem;
    }
    return hi;
}
// Fills inv[2..lim] with the recurrence inv[i] = inv[MOD%i]*(MOD-MOD/i) % MOD,
// after seeding inv[0] and inv[1] with 1 (so inv needs at least 2 elements).
// For a prime MOD > lim this yields the modular inverses of 1..lim.
inline void getInv(int inv[],const int lim,const int MOD){
    inv[0]=1;
    inv[1]=1;
    int i=2;
    while(i<=lim){
        // Widen to long long before multiplying so the product cannot overflow int.
        const long long prev=inv[MOD%i];
        inv[i]=prev*(MOD-MOD/i)%MOD;
        ++i;
    }
}
// Writes the decimal representation of x to stdout with putchar().
// Negative values are printed digit by digit without ever negating x itself:
// -(x/10) and -(x%10) are always representable, whereas -x overflows for the
// most negative value of a signed type.
template<class T>void fwrit(const T x){
    if(x<0){
        putchar('-');
        if(x<=-10)fwrit(-(x/10));   // leading digits, now non-negative
        putchar((-(x%10))^48);      // last digit; x%10 lies in [-9,0]
        return;
    }
    if(x>9)fwrit(x/10);
    putchar(x%10^48);               // d^48 maps 0..9 to '0'..'9'
}
// Computes a*b mod `mod` for operands whose product does not fit in 64 bits.
// The quotient q ~ floor(a*b/mod) is estimated in long double; a*b - q*mod is
// then evaluated modulo 2^64, where the true (small) remainder survives.
// The wrap-around is done in unsigned arithmetic because overflowing a signed
// long long is undefined behaviour. The final (%mod+mod)%mod absorbs an
// estimate of q that is off by one in either direction.
inline long long mulMod(const long long a,const long long b,const long long mod){//long long multiplie_mod
    typedef unsigned long long U;
    const long long q=static_cast<long long>((long double)a/mod*b+1e-8);
    const long long r=static_cast<long long>(
        static_cast<U>(a)*static_cast<U>(b)-static_cast<U>(q)*static_cast<U>(mod));
    return (r%mod+mod)%mod;
}

// Capacity constant used to size the arrays below.
const int MAXN=1e6;

// Input text; Init() stores it starting at s[1].
char s[MAXN+5];
// sa: suffix array filled by Getsa() (1-based).
// rnk[p][now]: current rank of the suffix starting at p; the other column
// (now^1) receives the ranks of the next doubling round, then `now` flips.
// rnk has about 2*MAXN rows because Getsa() indexes it with sa[i]+k.
int sa[MAXN+5],rnk[(MAXN<<1)+5][2],now;
int n,siz;// n: processed length (Init() counts the terminating '\0'); siz: number of buckets/ranks currently in use

// Reads one whitespace-delimited token into s[1..] and sets n to its length
// plus one, so that the terminating '\0' at s[n] takes part in the sort.
// The read is bounded to the buffer: s+1 offers sizeof(s)-1 bytes, one of
// which is reserved for the terminator, so longer input can no longer
// overrun s.
inline void Init(){
    char fmt[32];
    snprintf(fmt,sizeof(fmt),"%%%lus",static_cast<unsigned long>(sizeof(s)-2));
    scanf(fmt,s+1);
    n=strlen(s+1);
    s[++n]='\0';// same byte as the terminator scanf wrote; only n grows
}

// Bucket lists kept in flat arrays: node t stores the value Q[t] and the
// index of its successor nxt[t] (0 marks the end of a list).
// head[b] / tail[b] are the first / last node of bucket b, and qcnt is the
// number of nodes handed out so far.
int Q[MAXN+5];
int nxt[MAXN+5];
int tail[MAXN+5];
int head[MAXN+5];
int qcnt;

// Appends val to the end of bucket ind.
inline void Push(const int ind,const int val){
    const int node=++qcnt;
    Q[node]=val;
    nxt[node]=0;
    if(head[ind]){
        nxt[tail[ind]]=node;   // link behind the current last node
    }else{
        head[ind]=node;        // first node of an empty bucket
    }
    tail[ind]=node;
}

// Resets the node counter so node storage is reused; head[] and tail[] are
// not touched here.
inline void Clear(){
    qcnt=0;
}

// Builds the suffix array sa[1..n] of s[1..n] by prefix doubling, using the
// Push()/Clear() bucket lists as the radix sort. On return rnk[p][now] holds
// the rank of the suffix starting at p.
inline void Getsa(){
    /*------------------------- Stage 1: sort by first character -------------------------*/
    // Bucket by the unsigned byte value: plain `char` may be signed, and a
    // negative value would index head[]/tail[] out of bounds.
    rep(i,1,n)Push(static_cast<unsigned char>(s[i]),i);
    for(int i=0,x=0;i<=256;++i)if(tail[i]){
        ++siz;// one rank per distinct character that occurs
        while(head[i]){
            sa[++x]=Q[head[i]];
            rnk[Q[head[i]]][now]=siz;
            head[i]=nxt[head[i]];
        }
    }
    Clear();
    /*-------------------------     END     -------------------------*/

    /*------------------------- Stage 2: prefix doubling -------------------------*/
    for(int k=1;k<n;k<<=1){
        // Suffixes whose second half would start past n are queued first.
        for(int i=n-k+1;i<=n;++i)
            Push(rnk[i][now],i);
        // The others are visited in the order of their second half (sa) and
        // dropped into the bucket of their first-half rank.
        for(int i=1;i<=n;++i)if(sa[i]>k)
            Push(rnk[sa[i]-k][now],sa[i]-k);
        // Draining the buckets in rank order gives the order by (first, second).
        for(int i=0,x=0;i<=siz;++i)
            while(head[i]){
                sa[++x]=Q[head[i]];
                head[i]=nxt[head[i]];
            }
        siz=0;// recount the ranks for the doubled prefix length
        for(int i=1;i<=n;++i){
            // A new rank starts whenever either half differs from the previous
            // suffix. sa[0] is 0 and rnk[0] is never written, so i==1 always
            // opens rank 1.
            if(rnk[sa[i-1]][now]!=rnk[sa[i]][now] || rnk[sa[i-1]+k][now]!=rnk[sa[i]+k][now])++siz;
            rnk[sa[i]][now^1]=siz;
        }
        now^=1;Clear();// flip the rank column and recycle the list nodes
        if(siz==n)break;// every suffix has its own rank: sorting is finished
    }
    /*-------------------------     END     -------------------------*/
}

signed main(){
#ifdef FILEOI
    freopen("file.in","r",stdin);
    freopen("file.out","w",stdout);
#endif
    Init();
    Getsa();
    //第一个一定是最小的, 因此不管
    rep(i,2,n)printf("%d ",sa[i]);Endl;
    return 0;
}

经典写法实现基数排序

#include<cstdio>
#include<cstring>

#define rep(i,__l,__r) for(signed i=(__l),i##_end_=(__r);i<=i##_end_;++i)
#define fep(i,__l,__r) for(signed i=(__l),i##_end_=(__r);i>=i##_end_;--i)
#define writc(a,b) fwrit(a),putchar(b)
#define mp(a,b) make_pair(a,b)
#define ft first
#define sd second
#define LL long long
#define ull unsigned long long
#define uint unsigned int
#define pii pair< int,int >
#define Endl putchar('\n')
// #define FILEOI
// #define int long long
// #define int unsigned
// #define int unsigned long long

#ifdef FILEOI
# define MAXBUFFERSIZE 500000
    inline char fgetc(){
        static char buf[MAXBUFFERSIZE+5],*p1=buf,*p2=buf;
        return p1==p2&&(p2=(p1=buf)+fread(buf,1,MAXBUFFERSIZE,stdin),p1==p2)?EOF:*p1++;
    }
# undef MAXBUFFERSIZE
# define cg (c=fgetc())
#else
# define cg (c=getchar())
#endif
// Reads the next decimal integer from the input (via the `cg` macro) into x.
// Every character before the first digit is skipped; if any of the skipped
// characters is '-', the value is negated.
// If the input ends before a digit is found, x is set to 0 and the function
// returns (the skip loop used to spin forever on EOF). A 0xFF byte is
// indistinguishable from EOF here and also ends the read.
template<class T>inline void qread(T& x){
    char c;bool f=0;
    x=0;
    while(cg<'0'||'9'<c){
        if(c==static_cast<char>(EOF))return;
        f|=(c=='-');
    }
    // x*10 is written as (x<<1)+(x<<3); c^48 maps '0'..'9' to 0..9.
    for(x=(c^48);'0'<=cg&&c<='9';x=(x<<1)+(x<<3)+(c^48));
    if(f)x=-x;
}
// Reads the next decimal integer from the input (via the `cg` macro) and
// returns it. Characters before the first digit are skipped; a '-' among
// them negates the result.
// Returns 0 if the input ends before any digit is found (the skip loop used
// to spin forever on EOF). A 0xFF byte is indistinguishable from EOF here.
inline int qread(){
    int x=0;char c;bool f=0;
    while(cg<'0'||'9'<c){
        if(c==static_cast<char>(EOF))return 0;
        f|=(c=='-');
    }
    // x*10 is written as (x<<1)+(x<<3); c^48 maps '0'..'9' to 0..9.
    for(x=(c^48);'0'<=cg&&c<='9';x=(x<<1)+(x<<3)+(c^48));
    return f?-x:x;
}
// template<class T,class... Args>inline void qread(T& x,Args&... args){qread(x),qread(args...);}
// Returns the larger argument; y is returned when the two compare equal.
template<class T>inline T Max(const T x,const T y){
    if(x>y)return x;
    return y;
}
// Returns the smaller argument; y is returned when the two compare equal.
template<class T>inline T Min(const T x,const T y){
    if(x<y)return x;
    return y;
}
// Absolute value: x itself when it is positive, otherwise its negation.
template<class T>inline T fab(const T x){
    if(x>0)return x;
    return -x;
}
// Greatest common divisor by the Euclidean algorithm.
// gcd(a,0) is a; the remainder chain is (a,b) -> (b,a%b) until b is 0.
inline int gcd(const int a,const int b){
    int hi=a,lo=b;
    while(lo){
        const int rem=hi%lo;
        hi=lo;
        lo=rem;
    }
    return hi;
}
// Fills inv[2..lim] with the recurrence inv[i] = inv[MOD%i]*(MOD-MOD/i) % MOD,
// after seeding inv[0] and inv[1] with 1 (so inv needs at least 2 elements).
// For a prime MOD > lim this yields the modular inverses of 1..lim.
inline void getInv(int inv[],const int lim,const int MOD){
    inv[0]=1;
    inv[1]=1;
    int i=2;
    while(i<=lim){
        // Widen to long long before multiplying so the product cannot overflow int.
        const long long prev=inv[MOD%i];
        inv[i]=prev*(MOD-MOD/i)%MOD;
        ++i;
    }
}
// Writes the decimal representation of x to stdout with putchar().
// Negative values are printed digit by digit without ever negating x itself:
// -(x/10) and -(x%10) are always representable, whereas -x overflows for the
// most negative value of a signed type.
template<class T>void fwrit(const T x){
    if(x<0){
        putchar('-');
        if(x<=-10)fwrit(-(x/10));   // leading digits, now non-negative
        putchar((-(x%10))^48);      // last digit; x%10 lies in [-9,0]
        return;
    }
    if(x>9)fwrit(x/10);
    putchar(x%10^48);               // d^48 maps 0..9 to '0'..'9'
}
// Computes a*b mod `mod` for operands whose product does not fit in 64 bits.
// The quotient q ~ floor(a*b/mod) is estimated in long double; a*b - q*mod is
// then evaluated modulo 2^64, where the true (small) remainder survives.
// The wrap-around is done in unsigned arithmetic because overflowing a signed
// long long is undefined behaviour. The final (%mod+mod)%mod absorbs an
// estimate of q that is off by one in either direction.
inline long long mulMod(const long long a,const long long b,const long long mod){//long long multiplie_mod
    typedef unsigned long long U;
    const long long q=static_cast<long long>((long double)a/mod*b+1e-8);
    const long long r=static_cast<long long>(
        static_cast<U>(a)*static_cast<U>(b)-static_cast<U>(q)*static_cast<U>(mod));
    return (r%mod+mod)%mod;
}

// Capacity constant used to size the arrays below.
const int MAXN=1e6;

// Input text; Init() stores it starting at s[1].
char s[MAXN+5];
// sa: suffix array filled by Getsa() (1-based).
// rnk[p][now]: current rank of the suffix starting at p; the other column
// (now^1) receives the ranks of the next doubling round, then `now` flips.
int sa[MAXN+5],rnk[MAXN+5][2],now;
int n,siz;// n: processed length (Init() counts the terminating '\0'); siz: largest bucket index / rank currently in use

// Reads one whitespace-delimited token into s[1..] and sets n to its length
// plus one, so that the terminating '\0' at s[n] takes part in the sort.
// The read is bounded to the buffer: s+1 offers sizeof(s)-1 bytes, one of
// which is reserved for the terminator, so longer input can no longer
// overrun s.
inline void Init(){
    char fmt[32];
    snprintf(fmt,sizeof(fmt),"%%%lus",static_cast<unsigned long>(sizeof(s)-2));
    scanf(fmt,s+1);
    n=strlen(s+1);
    s[++n]='\0';// same byte as the terminator scanf wrote; only n grows
}

// tax: counting-sort buckets indexed by rank; tp: suffix start positions
// listed in the order of their second key for the current doubling round.
int tax[MAXN+5],tp[MAXN+5];

// Builds the suffix array sa[1..n] of s[1..n] by prefix doubling with a
// counting (radix) sort. On return rnk[p][now] holds the rank of the suffix
// starting at p.
inline void Getsa(){
    // Initial ranks are the unsigned byte values: plain `char` may be signed,
    // and a negative rank would index tax[] out of bounds.
    rep(i,1,n)++tax[rnk[i][now]=static_cast<unsigned char>(s[i])];
    rep(i,1,siz=256)tax[i]+=tax[i-1];
    fep(i,n,1)sa[tax[rnk[i][now]]--]=i;
    for(int k=1,x=0;k<=n;k<<=1,x=0){
        // Order by second key: suffixes without a second half come first,
        // the rest follow in the order given by the current sa.
        rep(i,n-k+1,n)tp[++x]=i;
        rep(i,1,n)if(sa[i]>k)tp[++x]=sa[i]-k;
        // Stable counting sort by first key; walking tp backwards keeps the
        // second-key order inside each bucket.
        rep(i,0,siz)tax[i]=0;
        rep(i,1,n)++tax[rnk[i][now]];
        rep(i,1,siz)tax[i]+=tax[i-1];
        fep(i,n,1){
            sa[tax[rnk[tp[i]][now]]--]=tp[i];
        }
        // Re-rank: neighbours share a rank only if both halves are equal.
        // The && short-circuit means rnk[...+k] is read only after the first
        // halves compared equal.
        rnk[sa[1]][now^1]=siz=1;
        rep(i,2,n){
            rnk[sa[i]][now^1]=(rnk[sa[i-1]][now]==rnk[sa[i]][now] && rnk[sa[i-1]+k][now]==rnk[sa[i]+k][now])?siz:++siz;
        }
        now^=1;// flip the active rank column
        if(siz==n)break;// every suffix has its own rank: sorting is finished
    }
}

signed main(){
#ifdef FILEOI
    freopen("file.in","r",stdin);
    freopen("file.out","w",stdout);
#endif
    Init();
    Getsa();
    //第一个一定是最小的, 因此不管
    rep(i,2,n)printf("%d ",sa[i]);Endl;
    return 0;
}

height 数组

\(\mathcal O(N\log^2 N)\) 时间复杂度

使用 \(\texttt{ST}\) 表求 height 数组,可能效率有点低,但是比较直观。

#include<cstdio>
#include<cstring>

#define rep(i,__l,__r) for(signed i=(__l),i##_end_=(__r);i<=i##_end_;++i)
#define fep(i,__l,__r) for(signed i=(__l),i##_end_=(__r);i>=i##_end_;--i)
#define writc(a,b) fwrit(a),putchar(b)
#define mp(a,b) make_pair(a,b)
#define ft first
#define sd second
#define LL long long
#define ull unsigned long long
#define uint unsigned int
#define pii pair< int,int >
#define Endl putchar('\n')
// #define FILEOI
// #define int long long
#define int unsigned
// #define int unsigned long long

#ifdef FILEOI
# define MAXBUFFERSIZE 500000
    inline char fgetc(){
        static char buf[MAXBUFFERSIZE+5],*p1=buf,*p2=buf;
        return p1==p2&&(p2=(p1=buf)+fread(buf,1,MAXBUFFERSIZE,stdin),p1==p2)?EOF:*p1++;
    }
# undef MAXBUFFERSIZE
# define cg (c=fgetc())
#else
# define cg (c=getchar())
#endif
// Reads the next decimal integer from the input (via the `cg` macro) into x.
// Every character before the first digit is skipped; if any of the skipped
// characters is '-', the value is negated.
// If the input ends before a digit is found, x is set to 0 and the function
// returns (the skip loop used to spin forever on EOF). A 0xFF byte is
// indistinguishable from EOF here and also ends the read.
template<class T>inline void qread(T& x){
    char c;bool f=0;
    x=0;
    while(cg<'0'||'9'<c){
        if(c==static_cast<char>(EOF))return;
        f|=(c=='-');
    }
    // x*10 is written as (x<<1)+(x<<3); c^48 maps '0'..'9' to 0..9.
    for(x=(c^48);'0'<=cg&&c<='9';x=(x<<1)+(x<<3)+(c^48));
    if(f)x=-x;
}
// Reads the next decimal integer from the input (via the `cg` macro) and
// returns it. Characters before the first digit are skipped; a '-' among
// them negates the result.
// Returns 0 if the input ends before any digit is found (the skip loop used
// to spin forever on EOF). A 0xFF byte is indistinguishable from EOF here.
inline int qread(){
    int x=0;char c;bool f=0;
    while(cg<'0'||'9'<c){
        if(c==static_cast<char>(EOF))return 0;
        f|=(c=='-');
    }
    // x*10 is written as (x<<1)+(x<<3); c^48 maps '0'..'9' to 0..9.
    for(x=(c^48);'0'<=cg&&c<='9';x=(x<<1)+(x<<3)+(c^48));
    return f?-x:x;
}
// template<class T,class... Args>inline void qread(T& x,Args&... args){qread(x),qread(args...);}
// Returns the larger argument; y is returned when the two compare equal.
template<class T>inline T Max(const T x,const T y){
    if(x>y)return x;
    return y;
}
// Returns the smaller argument; y is returned when the two compare equal.
template<class T>inline T Min(const T x,const T y){
    if(x<y)return x;
    return y;
}
// Absolute value: x itself when it is positive, otherwise its negation.
template<class T>inline T fab(const T x){
    if(x>0)return x;
    return -x;
}
// Greatest common divisor by the Euclidean algorithm.
// gcd(a,0) is a; the remainder chain is (a,b) -> (b,a%b) until b is 0.
inline int gcd(const int a,const int b){
    int hi=a,lo=b;
    while(lo){
        const int rem=hi%lo;
        hi=lo;
        lo=rem;
    }
    return hi;
}
// Fills inv[2..lim] with the recurrence inv[i] = inv[MOD%i]*(MOD-MOD/i) % MOD,
// after seeding inv[0] and inv[1] with 1 (so inv needs at least 2 elements).
// For a prime MOD > lim this yields the modular inverses of 1..lim.
inline void getInv(int inv[],const int lim,const int MOD){
    inv[0]=1;
    inv[1]=1;
    int i=2;
    while(i<=lim){
        // Widen to long long before multiplying so the product cannot overflow.
        const long long prev=inv[MOD%i];
        inv[i]=prev*(MOD-MOD/i)%MOD;
        ++i;
    }
}
// Writes the decimal representation of x to stdout with putchar().
// Negative values are printed digit by digit without ever negating x itself:
// -(x/10) and -(x%10) are always representable, whereas -x overflows for the
// most negative value of a signed type. For unsigned T the negative branch
// is never taken.
template<class T>void fwrit(const T x){
    if(x<0){
        putchar('-');
        if(x<=-10)fwrit(-(x/10));   // leading digits, now non-negative
        putchar((-(x%10))^48);      // last digit; x%10 lies in [-9,0]
        return;
    }
    if(x>9)fwrit(x/10);
    putchar(x%10^48);               // d^48 maps 0..9 to '0'..'9'
}
// Computes a*b mod `mod` for operands whose product does not fit in 64 bits.
// The quotient q ~ floor(a*b/mod) is estimated in long double; a*b - q*mod is
// then evaluated modulo 2^64, where the true (small) remainder survives.
// The wrap-around is done in unsigned arithmetic because overflowing a signed
// long long is undefined behaviour. The final (%mod+mod)%mod absorbs an
// estimate of q that is off by one in either direction.
inline long long mulMod(const long long a,const long long b,const long long mod){//long long multiplie_mod
    typedef unsigned long long U;
    const long long q=static_cast<long long>((long double)a/mod*b+1e-8);
    const long long r=static_cast<long long>(
        static_cast<U>(a)*static_cast<U>(b)-static_cast<U>(q)*static_cast<U>(mod));
    return (r%mod+mod)%mod;
}

// NOTE: this program defines `int` as `unsigned` through a macro above, so
// every `int` below is an unsigned integer.
// MAXN: capacity used to size the arrays; logMAXN: highest sparse-table level
// built by Getsa(); INF: initial value of the minimum in Ask().
const int MAXN=50000;
const int logMAXN=15;
const int INF=(1<<30)-1;

// Input text; Init() stores it starting at s[1].
char s[MAXN+5];
// sa: suffix array (1-based). rnk[p][now] / h[r][now]: current rank of the
// suffix starting at p and the value attached to rank r; column now^1 is
// written during a doubling round before `now` flips. n: processed length
// (Init() counts the terminating '\0').
int sa[MAXN+5],rnk[MAXN+5][2],h[MAXN+5][2],now,n;
// tax: counting-sort buckets; tp: positions ordered by second key;
// siz: largest bucket index / rank currently in use.
int tax[MAXN+5],tp[MAXN+5],siz;
// Number of test cases, read in main().
int T;

// Sparse table over h: Getsa() sets ST[i][0]=h[i][now] and
// ST[i][j]=min(ST[i][j-1],ST[i+2^(j-1)][j-1]).
int ST[MAXN+5][logMAXN+5];

// Reads one whitespace-delimited token into s[1..], sets n to its length plus
// one (so the terminating '\0' at s[n] takes part in the sort) and clears the
// per-test-case tables tax, h and ST.
// The read is bounded to the buffer: s+1 offers sizeof(s)-1 bytes, one of
// which is reserved for the terminator, so longer input can no longer
// overrun s.
inline void Init(){
    char fmt[32];
    snprintf(fmt,sizeof(fmt),"%%%lus",static_cast<unsigned long>(sizeof(s)-2));
    scanf(fmt,s+1);
    n=strlen(s+1);
    s[++n]='\0';// same byte as the terminator scanf wrote; only n grows
    rep(i,0,MAXN)tax[i]=0;
    rep(i,0,MAXN)h[i][0]=h[i][1]=0;
    rep(i,0,MAXN)rep(j,0,logMAXN)ST[i][j]=0;
}

// Returns the minimum of ST[.][0] over the index range (min(l,r), max(l,r)],
// or INF when l==r. The range is split into power-of-two pieces following the
// binary digits of its length.
// The length is computed with an explicit comparison: with `int` defined as
// `unsigned` in this program, fab(r-l) cannot undo the wrap-around of r-l
// when l>r, which would produce a huge length and reads outside ST.
inline int Ask(const int l,const int r){
    const int lo=(l<r)?l:r;
    int len=(l<r)?r-l:l-r;
    int ret=INF,p=lo+1,i=0;
    while(len){
        if(len&1){
            // Piece of length 2^i starting at p.
            ret=Min(ret,ST[p][i]);
            p+=(1<<i);
        }
        len>>=1,++i;
    }
    return ret;
}

// Builds the suffix array sa[1..n] by prefix doubling with a counting sort
// and, in the same pass, maintains h[r][now] for every rank r together with
// the sparse table ST over it.
// NOTE(review): h[r] appears to be the common-prefix length (capped at the
// current compared length) between rank r and rank r-1 — confirm.
inline void Getsa(){
    // Initial ranks are the unsigned byte values. Without the cast a negative
    // `char` would turn into a huge rank (`int` is `unsigned` here) and index
    // tax[] out of bounds.
    rep(i,1,n)++tax[rnk[i][now]=static_cast<unsigned char>(s[i])];
    rep(i,1,siz=256)tax[i]+=tax[i-1];
    fep(i,n,1)sa[tax[rnk[i][now]]--]=i;
    for(int k=1,x=0;k<n;k<<=1,x=0){
        // Order by second key: suffixes without a second half come first.
        rep(i,n-k+1,n)tp[++x]=i;
        rep(i,1,n)if(sa[i]>k)tp[++x]=sa[i]-k;
        // Stable counting sort by first key.
        rep(i,0,siz)tax[i]=0;
        rep(i,1,n)++tax[rnk[i][now]];
        rep(i,1,siz)tax[i]+=tax[i-1];
        fep(i,n,1)sa[tax[rnk[tp[i]][now]]--]=tp[i];
        siz=0;
        rep(i,1,n){++siz;
            // Same pair of halves as the previous suffix: stay in its rank.
            if(i!=1 && rnk[sa[i-1]][now]==rnk[sa[i]][now] && rnk[sa[i-1]+k][now]==rnk[sa[i]+k][now])--siz;
            else if(siz==1)h[siz][now^1]=1;
            // First halves differ: carry over the value of the old first-half rank.
            else if(rnk[sa[i-1]][now]!=rnk[sa[i]][now])h[siz][now^1]=h[rnk[sa[i]][now]][now];
            // First halves equal, second halves differ: k plus the range
            // minimum between the two second-half ranks.
            else h[siz][now^1]=k+Ask(rnk[sa[i-1]+k][now],rnk[sa[i]+k][now]);
            rnk[sa[i]][now^1]=siz;
        }
        now^=1;// flip the active rank / h column
        if(siz==n)break;// every suffix has its own rank: sorting is finished
        // Rebuild the sparse table over the new h for the next round's Ask().
        rep(i,2,siz)ST[i][0]=h[i][now];
        for(int j=1,L=1;j<=logMAXN;L<<=1,++j)for(int i=2;i+L<=siz;++i)
            ST[i][j]=Min(ST[i][j-1],ST[i+L][j-1]);
    }
}

inline void Solve(){
    int ans=n*(n-1)/2;
    for(int i=3;i<=n;++i)ans-=h[i][now];
    writc(ans,'\n');
}

// Reads the number of test cases, then for each one reads the string, builds
// the suffix array with its h values and prints the answer.
signed main(){
#ifdef FILEOI
    freopen("file.in","r",stdin);
    freopen("file.out","w",stdout);
#endif
    for(T=qread();T--;){
        Init();
        Getsa();
        Solve();
    }
    return 0;
}

\(\mathcal O(N)\) 时间复杂度

只是求 height 数组的时间复杂度为 \(\mathcal O(N)\),但是由于求 sa 数组的时间复杂度为 \(\mathcal O(N\log N)\),所以总体复杂度仍为 \(\mathcal O(N\log N)\)

#include<cstdio>
#include<cstring>

#define rep(i,__l,__r) for(signed i=(__l),i##_end_=(__r);i<=i##_end_;++i)
#define fep(i,__l,__r) for(signed i=(__l),i##_end_=(__r);i>=i##_end_;--i)
#define writc(a,b) fwrit(a),putchar(b)
#define mp(a,b) make_pair(a,b)
#define ft first
#define sd second
#define LL long long
#define ull unsigned long long
#define uint unsigned int
#define pii pair< int,int >
#define Endl putchar('\n')
// #define FILEOI
// #define int long long
#define int unsigned
// #define int unsigned long long

#ifdef FILEOI
# define MAXBUFFERSIZE 500000
    inline char fgetc(){
        static char buf[MAXBUFFERSIZE+5],*p1=buf,*p2=buf;
        return p1==p2&&(p2=(p1=buf)+fread(buf,1,MAXBUFFERSIZE,stdin),p1==p2)?EOF:*p1++;
    }
# undef MAXBUFFERSIZE
# define cg (c=fgetc())
#else
# define cg (c=getchar())
#endif
// Reads the next decimal integer from the input (via the `cg` macro) into x.
// Every character before the first digit is skipped; if any of the skipped
// characters is '-', the value is negated.
// If the input ends before a digit is found, x is set to 0 and the function
// returns (the skip loop used to spin forever on EOF). A 0xFF byte is
// indistinguishable from EOF here and also ends the read.
template<class T>inline void qread(T& x){
    char c;bool f=0;
    x=0;
    while(cg<'0'||'9'<c){
        if(c==static_cast<char>(EOF))return;
        f|=(c=='-');
    }
    // x*10 is written as (x<<1)+(x<<3); c^48 maps '0'..'9' to 0..9.
    for(x=(c^48);'0'<=cg&&c<='9';x=(x<<1)+(x<<3)+(c^48));
    if(f)x=-x;
}
// Reads the next decimal integer from the input (via the `cg` macro) and
// returns it. Characters before the first digit are skipped; a '-' among
// them negates the result.
// Returns 0 if the input ends before any digit is found (the skip loop used
// to spin forever on EOF). A 0xFF byte is indistinguishable from EOF here.
inline int qread(){
    int x=0;char c;bool f=0;
    while(cg<'0'||'9'<c){
        if(c==static_cast<char>(EOF))return 0;
        f|=(c=='-');
    }
    // x*10 is written as (x<<1)+(x<<3); c^48 maps '0'..'9' to 0..9.
    for(x=(c^48);'0'<=cg&&c<='9';x=(x<<1)+(x<<3)+(c^48));
    return f?-x:x;
}
// template<class T,class... Args>inline void qread(T& x,Args&... args){qread(x),qread(args...);}
// Returns the larger argument; y is returned when the two compare equal.
template<class T>inline T Max(const T x,const T y){
    if(x>y)return x;
    return y;
}
// Returns the smaller argument; y is returned when the two compare equal.
template<class T>inline T Min(const T x,const T y){
    if(x<y)return x;
    return y;
}
// Absolute value: x itself when it is positive, otherwise its negation.
template<class T>inline T fab(const T x){
    if(x>0)return x;
    return -x;
}
// Greatest common divisor by the Euclidean algorithm.
// gcd(a,0) is a; the remainder chain is (a,b) -> (b,a%b) until b is 0.
inline int gcd(const int a,const int b){
    int hi=a,lo=b;
    while(lo){
        const int rem=hi%lo;
        hi=lo;
        lo=rem;
    }
    return hi;
}
// Fills inv[2..lim] with the recurrence inv[i] = inv[MOD%i]*(MOD-MOD/i) % MOD,
// after seeding inv[0] and inv[1] with 1 (so inv needs at least 2 elements).
// For a prime MOD > lim this yields the modular inverses of 1..lim.
inline void getInv(int inv[],const int lim,const int MOD){
    inv[0]=1;
    inv[1]=1;
    int i=2;
    while(i<=lim){
        // Widen to long long before multiplying so the product cannot overflow.
        const long long prev=inv[MOD%i];
        inv[i]=prev*(MOD-MOD/i)%MOD;
        ++i;
    }
}
// Writes the decimal representation of x to stdout with putchar().
// Negative values are printed digit by digit without ever negating x itself:
// -(x/10) and -(x%10) are always representable, whereas -x overflows for the
// most negative value of a signed type. For unsigned T the negative branch
// is never taken.
template<class T>void fwrit(const T x){
    if(x<0){
        putchar('-');
        if(x<=-10)fwrit(-(x/10));   // leading digits, now non-negative
        putchar((-(x%10))^48);      // last digit; x%10 lies in [-9,0]
        return;
    }
    if(x>9)fwrit(x/10);
    putchar(x%10^48);               // d^48 maps 0..9 to '0'..'9'
}
// Computes a*b mod `mod` for operands whose product does not fit in 64 bits.
// The quotient q ~ floor(a*b/mod) is estimated in long double; a*b - q*mod is
// then evaluated modulo 2^64, where the true (small) remainder survives.
// The wrap-around is done in unsigned arithmetic because overflowing a signed
// long long is undefined behaviour. The final (%mod+mod)%mod absorbs an
// estimate of q that is off by one in either direction.
inline long long mulMod(const long long a,const long long b,const long long mod){//long long multiplie_mod
    typedef unsigned long long U;
    const long long q=static_cast<long long>((long double)a/mod*b+1e-8);
    const long long r=static_cast<long long>(
        static_cast<U>(a)*static_cast<U>(b)-static_cast<U>(q)*static_cast<U>(mod));
    return (r%mod+mod)%mod;
}

// NOTE: this program defines `int` as `unsigned` through a macro above, so
// every `int` below is an unsigned integer.
// MAXN: capacity used to size the arrays. logMAXN and INF are not referenced
// by the functions visible in this program.
const int MAXN=50000;
const int logMAXN=15;
const int INF=(1<<30)-1;

// Input text; Init() stores it starting at s[1].
char s[MAXN+5];
// sa: suffix array (1-based). rnk[p][now]: current rank of the suffix
// starting at p (column now^1 is scratch for the next doubling round).
// h[r]: value computed by Gethe() for rank r. n: processed length (Init()
// counts the terminating '\0').
int sa[MAXN+5],rnk[MAXN+5][2],h[MAXN+5],now,n;
// tax: counting-sort buckets; tp: positions ordered by second key;
// siz: largest bucket index / rank currently in use.
int tax[MAXN+5],tp[MAXN+5],siz;
// Number of test cases, read in main().
int T;

// Reads one whitespace-delimited token into s[1..], sets n to its length plus
// one (so the terminating '\0' at s[n] takes part in the sort) and clears the
// per-test-case tables tax and h.
// The read is bounded to the buffer: s+1 offers sizeof(s)-1 bytes, one of
// which is reserved for the terminator, so longer input can no longer
// overrun s.
inline void Init(){
    char fmt[32];
    snprintf(fmt,sizeof(fmt),"%%%lus",static_cast<unsigned long>(sizeof(s)-2));
    scanf(fmt,s+1);
    n=strlen(s+1);
    s[++n]='\0';// same byte as the terminator scanf wrote; only n grows
    rep(i,0,MAXN)tax[i]=0;
    rep(i,0,MAXN)h[i]=0;
}

// Builds the suffix array sa[1..n] of s[1..n] by prefix doubling with a
// counting (radix) sort. On return rnk[p][now] holds the rank of the suffix
// starting at p (when at least one doubling round ran, i.e. n>=2).
inline void Getsa(){
    // Initial ranks are the unsigned byte values. Without the cast a negative
    // `char` would turn into a huge rank (`int` is `unsigned` here) and index
    // tax[] out of bounds.
    rep(i,1,n)++tax[rnk[i][now]=static_cast<unsigned char>(s[i])];
    rep(i,1,siz=256)tax[i]+=tax[i-1];
    fep(i,n,1)sa[tax[rnk[i][now]]--]=i;
    for(int k=1,x=0;k<n;k<<=1,x=0){
        // Order by second key: suffixes without a second half come first,
        // the rest follow in the order given by the current sa.
        rep(i,n-k+1,n)tp[++x]=i;
        rep(i,1,n)if(sa[i]>k)tp[++x]=sa[i]-k;
        // Stable counting sort by first key; walking tp backwards keeps the
        // second-key order inside each bucket.
        rep(i,0,siz)tax[i]=0;
        rep(i,1,n)++tax[rnk[i][now]];
        rep(i,1,siz)tax[i]+=tax[i-1];
        fep(i,n,1)sa[tax[rnk[tp[i]][now]]--]=tp[i];
        siz=0;
        rep(i,1,n){++siz;
            // Same pair of halves as the previous suffix: stay in its rank.
            if(i!=1 && rnk[sa[i-1]][now]==rnk[sa[i]][now] && rnk[sa[i-1]+k][now]==rnk[sa[i]+k][now])--siz;
            rnk[sa[i]][now^1]=siz;
        }
        now^=1;// flip the active rank column
        if(siz==n)break;// every suffix has its own rank: sorting is finished
    }
}

inline void Gethe(){
    int k=0,j;
    rep(i,1,n){
        if(rnk[i][now]==1)continue;
        j=sa[rnk[i][now]-1];
        if(k)--k;
        while(i+k<=n && j+k<=n && s[i+k]==s[j+k])++k;
        h[rnk[i][now]]=k;
    }
}

inline void Solve(){
    int ans=n*(n-1)/2;
    for(int i=3;i<=n;++i)ans-=h[i];
    writc(ans,'\n');
}

// Reads the number of test cases, then for each one reads the string, builds
// the suffix array, computes h and prints the answer.
signed main(){
#ifdef FILEOI
    freopen("file.in","r",stdin);
    freopen("file.out","w",stdout);
#endif
    for(T=qread();T--;){
        Init();
        Getsa();
        Gethe();
        Solve();
    }
    return 0;
}
posted @ 2020-02-28 15:57  Arextre  阅读(140)  评论(0编辑  收藏  举报