BZOJ3796 Mushroom追妹纸(二分答案+后缀数组+KMP)
求出一个串使得这个串是\(s1,s2\)的子串。串中不包含\(s3\)。
如果没有这个\(s3\)就可以二分答案,然后height小于二分值分一组。看看每组里是不是出现过\(s1,s2\)的后缀。判断就行。
然后有了\(s3\)之后,我们考虑改变一下height数组。
我们把\(s1s2\)拼在一起构成一个新串\(s\)。(中间隔一个#)
设\(s3\)的长度为\(len\)。
显然对于s中出现\(s3\)的起始位置\(x\)。\(height[rk[x]]\)要小于\(len\),\(height[rk[x-1]\)要小于\(len+1\),\(height[rk[x-2]]\)要小于\(len+2\)...
这样复杂度可能是\(O(n^2)\)的,其实我们只需要把初值为INF的\(g[x]\)设为\(len\),然后倒着做一遍\(g[i]=min(g[i],g[i+1])\),\(height[rk[i]]=min(height[rk[i]],g[i])\)就行。
所以我们需要知道\(s\)中\(s3\)的起始位置。这个用KMP解决。
然后愉快的二分就行了。
#include<iostream>
#include<cstring>
#include<cmath>
#include<algorithm>
#include<cstdio>
using namespace std;
const int N=201010;
const int INF=1e9;
int sa[N],n,m,c[N],x[N],y[N],rk[N],height[N];
int n1,n2,n3,nxt[N],g[N],ans;
char s1[N],s2[N],s3[N],s[N];
void get_sa(){
for(int i=1;i<=n;i++)c[x[i]=s[i]]++;
for(int i=1;i<=m;i++)c[i]+=c[i-1];
for(int i=n;i>=1;i--)sa[c[x[i]]--]=i;
for(int k=1;k<=n;k<<=1){
int num=0;
for(int i=n-k+1;i<=n;i++)y[++num]=i;
for(int i=1;i<=n;i++)if(sa[i]>k)y[++num]=sa[i]-k;
for(int i=1;i<=m;i++)c[i]=0;
for(int i=1;i<=n;i++)c[x[i]]++;
for(int i=1;i<=m;i++)c[i]+=c[i-1];
for(int i=n;i>=1;i--)sa[c[x[y[i]]]--]=y[i],y[i]=0;
for(int i=1;i<=n;i++)swap(x[i],y[i]);
x[sa[1]]=1;num=1;
for(int i=2;i<=n;i++)
x[sa[i]]=(y[sa[i]]==y[sa[i-1]]&&y[sa[i]+k]==y[sa[i-1]+k])?num:++num;
if(num==n)break;
m=num;
}
}
void get_height(){
int k=0;
for(int i=1;i<=n;i++)rk[sa[i]]=i;
for(int i=1;i<=n;i++){
if(rk[i]==1)continue;
if(k)k--;
int j=sa[rk[i]-1];
while(i+k<=n&&j+k<=n&&s[j+k]==s[i+k])k++;
height[rk[i]]=k;
}
for(int i=1;i<=n;i++)height[rk[i]]=min(height[rk[i]],g[i]);
}
bool judge(int x){
bool flag1=false,flag2=false;
for(int i=1;i<=n;i++){
if(height[i]<x){
flag1=flag2=false;
if(sa[i]<=n1)flag1=true;
else flag2=true;
}
else{
if(sa[i]<=n1)flag1=true;
else flag2=true;
if(flag1&&flag2)return true;
}
}
return false;
}
int main(){
scanf("%s",s1+1);
scanf("%s",s2+1);
n1=strlen(s1+1);
n2=strlen(s2+1);
for(int i=1;i<=n1;i++)s[i]=s1[i];
s[n1+1]='#';
for(int i=1;i<=n2;i++)s[n1+i+1]=s2[i];
n=n1+n2+1;
scanf("%s",s3+1);
n3=strlen(s3+1);
for(int i=2,j=0;i<=n3;i++){
while(j&&s3[j+1]!=s3[i])j=nxt[j];
if(s3[j+1]==s3[i])j++;
nxt[i]=j;
}
for(int i=1;i<=n;i++)g[i]=INF;
for(int i=1,j=0;i<=n;i++){
while(j&&(s3[j+1]!=s[i]||j==n3))j=nxt[j];
if(s3[j+1]==s[i])j++;
if(j==n3)g[i-n3+1]=n3-1;
}
for(int i=n;i>=1;i--)g[i]=min(g[i],g[i+1]+1);
m=122;get_sa();get_height();
int l=1,r=n;
while(l<=r){
int mid=(l+r)>>1;
if(judge(mid)){
ans=mid;
l=mid+1;
}
else r=mid-1;
}
printf("%d",ans);
return 0;
}