题解:
先把 s1、一个分隔符、s2 拼接起来，跑一遍后缀数组（SA），并求出 height 数组
然后用 KMP 分别在 s1 和 s2 中匹配 s3，标记出 s3 每次出现的起始位置（公共子串不能完整包含这些出现）
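最后枚举 SA 中相邻且分别来自 s1、s2 的两个后缀：它们的公共前缀长度为 height，用 ST 表加二分找出这段公共前缀内第一次完整出现 s3 的位置，把长度截断到这次出现的最后一个字符之前；所有截断结果的最大值即为答案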
代码:
#include <bits/stdc++.h>
using namespace std;

// Longest common substring of s1 and s2 that does not contain s3.
// Approach: suffix array over s1 + separator + s2, KMP to mark where s3 starts,
// and a sparse table + binary search to cut each candidate LCP before the first
// full occurrence of s3 inside it.

const int N = 100003;     // capacity of each input string buffer
const int M = 2 * N + 5;  // capacity of arrays indexed over the concatenation (1-based)
const int LOG = 18;       // sparse-table levels; 2^17 <= M - 1 < 2^18

int t[N];                 // KMP failure function of s3 (t[0] = -1 is the sentinel)
int a[M];                 // concatenation as integers: letters 'a'..'z' -> 1..26, separator -> 0
int xx[M], yy[M];         // backing storage for the two rank buffers used by get_sa
int *x, *y;               // current / previous rank arrays, swapped every doubling round
int height[M];            // height[i] = LCP of suffixes sa[i-1] and sa[i]
int rk[M];                // inverse of sa; named rk because `rank` is ambiguous with std::rank
int sa[M];                // sa[i] = start position of the i-th smallest suffix
int n, m, len, len1;      // |s1|, |s2|, |concatenation|, |s3|
int pd[M];                // pd[q] = 1 if s3 starts at position q of the concatenation
int pd1[N], pd2[N];       // s3 occurrence starts in s1 / s2 (0-based indices)
int v[M];                 // counting-sort buckets
int p;                    // scratch counter shared by get_sa
int st[LOG][M];           // sparse table: st[k][j] = max(pd[j .. j + 2^k - 1])
int L[M];                 // L[d] = floor(log2(d)) for d >= 1, L[0] = 0
char s1[N], s2[N], s3[N];

// Builds the KMP failure function of s3: t[i] is the length of the longest
// proper border of s3[0..i-1], with t[0] = -1 as the sentinel.
void get_fail() {
    t[0] = -1;
    for (int i = 0; i < len1; i++) {
        int j = t[i];
        while (j != -1 && s3[j] != s3[i]) j = t[j];
        t[i + 1] = j + 1;
    }
}

// Scans s (length l) for s3 and sets res[q] = 1 for every index q where an
// occurrence of s3 starts. Requires get_fail() to have been called.
void kmp(const char *s, int *res, int l) {
    int i = 0;  // matched length in s3 (may be the -1 sentinel)
    int j = 0;  // position in s
    // j == l compares against the terminating '\0', which never matches s3.
    while (j <= l) {
        // Test the sentinel first: evaluating s3[-1] would read out of bounds.
        if (i == -1 || s3[i] == s[j]) {
            i++;
            j++;
        } else {
            i = t[i];
        }
        if (i == len1) {
            res[j - len1] = 1;
            i = t[i];  // continue so overlapping occurrences are found too
        }
    }
}

// Nonzero when suffixes i and j agree on their first 2k characters, judged by
// the previous round's ranks in y (a second half past the end ranks as -1).
int cmp(int i, int j, int k) {
    return y[i] == y[j] &&
           (i + k > len ? -1 : y[i + k]) == (j + k > len ? -1 : y[j + k]);
}

// Builds sa, rk and height for a[1..len]: prefix doubling with radix sort,
// then Kasai's linear LCP pass.
void get_sa() {
    x = xx;
    y = yy;
    int m1 = 30;  // bucket bound for the initial alphabet (values 0..26)
    for (int i = 0; i <= m1; i++) v[i] = 0;
    for (int i = 1; i <= len; i++) v[x[i] = a[i]]++;
    for (int i = 1; i <= m1; i++) v[i] += v[i - 1];
    for (int i = len; i >= 1; i--) sa[v[x[i]]--] = i;

    for (int k = 1; k <= len; k <<= 1) {
        // Order by second key: suffixes with no second half come first.
        p = 0;
        for (int i = len - k + 1; i <= len; i++) y[++p] = i;
        for (int i = 1; i <= len; i++)
            if (sa[i] > k) y[++p] = sa[i] - k;

        // Stable counting sort by first key (clear bucket 0 too: it holds the separator).
        for (int i = 0; i <= m1; i++) v[i] = 0;
        for (int i = 1; i <= len; i++) v[x[y[i]]]++;
        for (int i = 1; i <= m1; i++) v[i] += v[i - 1];
        for (int i = len; i >= 1; i--) sa[v[x[y[i]]]--] = y[i];

        // Re-rank; y now holds the previous ranks that cmp() reads.
        swap(x, y);
        p = 2;
        x[sa[1]] = 1;
        for (int i = 2; i <= len; i++)
            x[sa[i]] = cmp(sa[i], sa[i - 1], k) ? p - 1 : p++;
        if (p > len) break;  // all ranks distinct
        m1 = p + 1;
    }

    for (int i = 1; i <= len; i++) rk[sa[i]] = i;

    // Kasai: LCP can drop by at most 1 when moving from suffix i to i + 1.
    p = 0;
    for (int i = 1; i <= len; i++) {
        if (rk[i] == 1) {
            p = 0;  // no predecessor; restart from a safe lower bound
            continue;
        }
        int j = sa[rk[i] - 1];
        while (i + p <= len && j + p <= len && a[i + p] == a[j + p]) p++;
        height[rk[i]] = p;
        p = max(0, p - 1);
    }
}

// max(pd[l..r]) via the sparse table, for 1 <= l <= r <= len.
int calc(int l, int r) {
    int k = L[r - l];
    return max(st[k][l], st[k][r - (1 << k) + 1]);
}

// Smallest q in [l, r] with pd[q] == 1, or r + 1 if none exists
// (also r + 1 when r < l). Binary search on the monotone prefix maximum.
int divide(int l, int r) {
    const int start = l;
    int ans = r + 1;
    while (l <= r) {
        int mid = (l + r) / 2;
        if (calc(start, mid)) {
            ans = mid;
            r = mid - 1;
        } else {
            l = mid + 1;
        }
    }
    return ans;
}

int main() {
    if (scanf("%s", s1) != 1) return 0;
    n = strlen(s1);
    if (scanf("%s", s2) != 1) return 0;
    m = strlen(s2);
    if (scanf("%s", s3) != 1) return 0;
    len1 = strlen(s3);

    get_fail();
    kmp(s1, pd1, n);
    kmp(s2, pd2, m);

    // Layout: 1..n = s1, n+1 = separator (0), n+2..n+m+1 = s2.
    for (int i = 1; i <= n; i++) {
        a[i] = s1[i - 1] - 'a' + 1;
        pd[i] = pd1[i - 1];
    }
    a[n + 1] = 0;
    pd[n + 1] = 0;
    len = n + 1;
    for (int i = 1; i <= m; i++) {
        a[++len] = s2[i - 1] - 'a' + 1;
        pd[len] = pd2[i - 1];
    }

    get_sa();

    // Sparse table over pd for range-max queries.
    for (int i = 1; i <= len; i++) st[0][i] = pd[i];
    for (int k = 1; (1 << k) <= len; k++)
        for (int j = 1; j + (1 << k) - 1 <= len; j++)
            st[k][j] = max(st[k - 1][j], st[k - 1][j + (1 << (k - 1))]);
    for (int i = 1, j = 0; i <= len; i++) {
        if ((1 << (j + 1)) <= i) j++;
        L[i] = j;
    }

    // Only SA-adjacent pairs from different strings can give the optimum.
    // Cut each LCP just before the end of the first s3 occurrence it fully contains.
    int ans = 0;
    for (int i = 2; i <= len; i++) {
        bool cross = (sa[i] <= n && sa[i - 1] > n + 1) ||
                     (sa[i] > n + 1 && sa[i - 1] <= n);
        if (!cross) continue;
        int h = height[i];
        // Occurrences fully inside the common prefix start in [sa[i], sa[i] + h - len1].
        int pos = divide(sa[i], sa[i] + h - len1);
        ans = max(ans, min(h, pos - sa[i] + len1 - 1));
    }
    printf("%d\n", ans);
    return 0;
}