[2016北京集训测试赛1]兔子的字符串-[后缀数组+二分]
Description
Solution
由于这道题很难计算出一个答案,我们考虑二分。
既然要二分,我们需要能在很短时间内求出字符串的大小关系,可以考虑后缀数组(它可以直接把后缀排序,还可以算相邻串的公共前缀)。
将所有的后缀从小到大排完序后,我们二分某个后缀,使它为答案,判断划分的段数。
假如我们目前查找到的串的左端点固定,易知其右端点向右挪的过程中,答案绝对不会变小。(满足贪心性质)
贪心,假如目前二分到后缀i,定义dis[j],满足s[j,dis[j]+j-1]内的字典序最大的子串<=后缀i,且dis[j]要尽量大。
则排名在i之前(即字典序比后缀i小)的后缀,dis[j]直接为len(s的长度)就好,因为它可以延伸到字符串末尾的,你设为正无穷也可以,反正它绝对不会影响最后划分的段数。但是字典序比i大的后缀j,它的dis[j]只能为它和后缀i的最大公共前缀了。
确定了答案在某个后缀i里后,在后缀i里继续二分即可。
Code
#include<iostream> #include<cstdio> #include<cstring> #include<cmath> using namespace std; int n,sa[100010],rk[100010],height[100010],cnt[100010]; char s[100010];int k; bool check(int *x,int a,int b,int j) { if (a+j>n&&b+j>n) return x[a]==x[b]; if (a+j>n||b+j>n) return 0; return x[a]==x[b]&&x[a+j]==x[b+j]; } void build_sa() { int *x=rk,*y=height,m='z'; for (int i=1;i<=n;i++) cnt[x[i]=s[i]]++; for (int i=1;i<=m;i++) cnt[i]+=cnt[i-1]; for (int i=n;i;i--) sa[cnt[x[i]]--]=i; for (int j=1,pos=0;pos<n;m=pos,j<<=1) { pos=0; for (int i=n-j+1;i<=n;i++) y[++pos]=i; for (int i=1;i<=n;i++) if (sa[i]>j) y[++pos]=sa[i]-j; for (int i=1;i<=m;i++) cnt[i]=0; for (int i=1;i<=n;i++) cnt[x[y[i]]]++; for (int i=1;i<=m;i++) cnt[i]+=cnt[i-1]; for (int i=n;i;i--) sa[cnt[x[y[i]]]--]=y[i]; swap(x,y); pos=1;x[sa[1]]=1; for (int i=2;i<=n;i++) x[sa[i]]=check(y,sa[i-1],sa[i],j)?pos:++pos; } } void get_height() { int k=0; for (int i=1;i<=n;i++) rk[sa[i]]=i; for (int i=1;i<=n;height[rk[i++]]=k) for (k?--k:0;s[i+k]==s[sa[rk[i]-1]+k];k++); } int dis[100010]; int solve(int l,int r) { for (int i=1;i<=n;i++) dis[i]=n; dis[l]=r-l+1; for (int i=rk[l]+1;i<=n;i++) { dis[sa[i]]=min(dis[sa[i-1]],height[i]); if (!dis[sa[i]]) return n+1; } int mn=n+1,num=0; for (int i=1;i<=n;i++) { if (mn<i) num++,mn=n+1; mn=min(mn,i+dis[i]-1); } num++; return num; } int main() { scanf("%d%s",&k,s+1); n=strlen(s+1); build_sa(); get_height(); int l=1,r=n,mid; while (l<r) { mid=(l+r)/2; if (solve(sa[mid],n)>k) l=mid+1;else r=mid; } l=sa[l]; int l1=1,r1=n-l+1; while (l1<r1) { mid=(l1+r1)/2; if (solve(l,l+mid-1)>k) l1=mid+1; else r1=mid; } for (int i=l;i<l+l1;i++) printf("%c",s[i]); }