BZOJ3998 TJOI2015弦论(后缀数组+二分答案)

  先看t=1的情况。显然得求出SA(因为我不会SAM)。我们一位位地确定答案。设填到了第len位,二分这一位填什么之后,在已经确定的答案所在的范围(SA上的某段区间)内二分,找到最后一个小于当前串的后缀,那么从区间左端点到该位置的这些后缀的所有前缀都要比二分出的答案小,判一下是否合法。确定了这一位填什么之后,还要找到最后一个前len位小于等于当前串的后缀,若加上这一部分后比答案串小的已经超过k个的话,则答案已经确定可以直接退出了,否则将这些计入并继续填下一位,更新这些前len位等于答案串的后缀为答案范围。注意算的时候要减去已经计入的部分。

  t=0类似,算出height数组后可以去掉重复子串。注意一些细节就好。

  时间复杂度O(nlognlog|s|),s为字符集大小也就是26,跑的还挺快。

  我怎么还不学SAM啊。

#include<iostream> 
#include<cstdio>
#include<cmath>
#include<cstdlib>
#include<cstring>
#include<algorithm>
using namespace std;
int read()
{
    int x=0,f=1;char c=getchar();
    while (c<'0'||c>'9') {if (c=='-') f=-1;c=getchar();}
    while (c>='0'&&c<='9') x=(x<<1)+(x<<3)+(c^48),c=getchar();
    return x*f;
}
#define N 500010
char a[N],ans[N];
int T,k,n,sa[N],sa2[N<<1],rk[N<<1],tmp[N<<1],cnt[N],h[N];
long long sum[N];
void make()
{
    memset(cnt,0,sizeof(cnt));
    int m=0;
    for (int i=1;i<=n;i++) cnt[rk[i]=a[i]]++,m=max(m,(int)a[i]);
    for (int i=1;i<=m;i++) cnt[i]+=cnt[i-1];
    for (int i=n;i>=1;i--) sa[cnt[rk[i]]--]=i;
    for (int k=1;k<=n;k<<=1)
    {
        int p=0;
        for (int i=n-k+1;i<=n;i++) sa2[++p]=i;
        for (int i=1;i<=n;i++) if (sa[i]>k) sa2[++p]=sa[i]-k;
        memset(cnt,0,m+1<<2);
        for (int i=1;i<=n;i++) cnt[rk[i]]++;
        for (int i=1;i<=m;i++) cnt[i]+=cnt[i-1];
        for (int i=n;i>=1;i--) sa[cnt[rk[sa2[i]]]--]=sa2[i];
        memcpy(tmp,rk,sizeof((rk)));
        p=1;rk[sa[1]]=1;
        for (int i=2;i<=n;i++)
        {
            if (tmp[sa[i]]!=tmp[sa[i-1]]||tmp[sa[i]+k]!=tmp[sa[i-1]+k]) p++;
            rk[sa[i]]=p;
        }
        if (p>=n) break;
        m=p;
    }
}
int lower_find(int len,int l,int r,char c)
{
    int ans=l-1;
    while (l<=r)
    {
        int mid=l+r>>1;
        if (a[sa[mid]+len-1]<c) ans=mid,l=mid+1;
        else r=mid-1;
    }
    return ans;
}
int upper_find(int len,int l,int r,char c)
{
    int ans=l;
    while (l<=r)
    {
        int mid=l+r>>1;
        if (a[sa[mid]+len-1]<=c) ans=mid,l=mid+1;
        else r=mid-1;
    }
    return ans;
}
int main()
{
    freopen("bzoj3998.in","r",stdin);
    freopen("bzoj3998.out","w",stdout);
    scanf("%s",a+1);n=strlen(a+1);
    cin>>T>>k;
    make();
    int len=0;
    if (T==1)
    {
        for (int i=1;i<=n;i++) sum[i]=sum[i-1]+n-sa[i]+1;
        int l=1,r=n;
        while (len<=n)
        {
            len++;
            char lc='a',rc='z';
            while (lc<=rc)
            {
                char midc=lc+rc>>1;
                int x=lower_find(len,l,r,midc);
                if (sum[x]-sum[l-1]-(len-1)*(x-l+1)<k) lc=midc+1,ans[len]=midc;
                else rc=midc-1;
            }
            int x=lower_find(len,l,r,ans[len]),y=upper_find(len,l,r,ans[len]);
            k-=sum[x]-sum[l-1]-(len-1)*(x-l+1)+y-x;
            if (k<=0) break;
            l=x+1,r=y;
        }
    }
    else
    {
        for (int i=1;i<=n;i++)
        {
            h[i]=max(h[i-1]-1,0);
            while (a[i+h[i]]==a[sa[rk[i]-1]+h[i]]) h[i]++;
        }
        for (int i=1;i<=n;i++) sum[i]=sum[i-1]+n-sa[i]+1-h[sa[i]];
        int l=1,r=n;
        while (len<=n)
        {
            len++;
            char lc='a',rc='z';
            while (lc<=rc)
            {
                char midc=lc+rc>>1;
                int x=lower_find(len,l,r,midc);
                if ((x==l-1?0:sum[x]-sum[l-1]+h[sa[l]]-len+1)<k) lc=midc+1,ans[len]=midc;
                else rc=midc-1;
            }
            int x=lower_find(len,l,r,ans[len]),y=upper_find(len,l,r,ans[len]);
            k-=(x==l-1?0:sum[x]-sum[l-1]+h[sa[l]]-len+1)+1;
            if (k<=0) break;
            l=x+1,r=y;
        }
    }
    if (k>0) cout<<-1;
    else for (int i=1;i<=len;i++) printf("%c",ans[i]);
    fclose(stdin);fclose(stdout);
    return 0;
}

 

posted @ 2018-08-02 13:16  Gloid  阅读(239)  评论(0编辑  收藏  举报