bzoj 3230: 相似子串【SA+st表+二分】
总是犯低级错误,st表都能写错……
正反分别做一遍SA,预处理st表方便查询lcp,然后处理a[i]表示前i个后缀一共有多少个本质不同的子串,这里的子串是按字典序的,所以询问的时候直接在a上二分排名就能得到询问区间,然后用正反st表查lcp即可
#include<iostream>
#include<cstdio>
#include<algorithm>
using namespace std;
const int N=200005;
int n,q,b[N],sa1[N],sa2[N],rk1[N],rk2[N],he1[N],he2[N],st1[20][N],st2[20][N],wa[N],wb[N],wv[N],wsu[N];
long long a[N];
char s[N];
long long read()
{
long long r=0,f=1;
char p=getchar();
while(p>'9'||p<'0')
{
if(p=='-')
f=-1;
p=getchar();
}
while(p>='0'&&p<='9')
{
r=r*10+p-48;
p=getchar();
}
return r*f;
}
bool cmp(int r[],int a,int b,int l)
{
return r[a]==r[b]&&r[a+l]==r[b+l];
}
void saa(char r[],int n,int m,int sa[],int rk[],int he[])
{
int *x=wa,*y=wb;
for(int i=0;i<=m;i++)
wsu[i]=0;
for(int i=1;i<=n;i++)
wsu[x[i]=r[i]]++;
for(int i=1;i<=m;i++)
wsu[i]+=wsu[i-1];
for(int i=n;i>=1;i--)
sa[wsu[x[i]]--]=i;
for(int j=1,p=1;j<=n&&p<n;j<<=1,m=p)
{
p=0;
for(int i=n-j+1;i<=n;i++)
y[++p]=i;
for(int i=1;i<=n;i++)
if(sa[i]>j)
y[++p]=sa[i]-j;
for(int i=1;i<=n;i++)
wv[i]=x[y[i]];
for(int i=0;i<=m;i++)
wsu[i]=0;
for(int i=1;i<=n;i++)
wsu[wv[i]]++;
for(int i=1;i<=m;i++)
wsu[i]+=wsu[i-1];
for(int i=n;i>=1;i--)
sa[wsu[wv[i]]--]=y[i];
swap(x,y);
x[sa[1]]=1;
p=1;
for(int i=2;i<=n;i++)
x[sa[i]]=cmp(y,sa[i-1],sa[i],j)?p:++p;
}
for(int i=1;i<=n;i++)
rk[sa[i]]=i;
for(int i=1,j,k=0;i<=n;he[rk[i++]]=k)
for(k?k--:0,j=sa[rk[i]-1];r[i+k]==r[j+k];k++);
}
int ef(long long x)
{
int l=1,r=n,ans=1;
while(l<=r)
{
int mid=(l+r)>>1;
if(a[mid]>=x)
r=mid-1,ans=mid;
else
l=mid+1;
}
return sa1[ans];
}
long long ques1(int x,int y)
{
if(x==y)
return n-x+1;
int l=min(rk1[x],rk1[y])+1,r=max(rk1[x],rk1[y]),k=b[r-l+1];
return min(st1[k][l],st1[k][r-(1<<k)+1]);
}
long long ques2(int x,int y)
{
if(x==y)
return n-x+1;
int l=min(rk2[x],rk2[y])+1,r=max(rk2[x],rk2[y]),k=b[r-l+1];
return min(st2[k][l],st2[k][r-(1<<k)+1]);
}
int main()
{
scanf("%d%d%s",&n,&q,s+1);
saa(s,n,200,sa1,rk1,he1);
reverse(s+1,s+1+n);
saa(s,n,200,sa2,rk2,he2);
b[0]=-1;
for(int i=1;i<=n;i++)
b[i]=b[i>>1]+1;
for(int i=1;i<=n;i++)
st1[0][i]=he1[i],st2[0][i]=he2[i];
for(int i=1;i<=17;i++)
for(int j=1;j+(1<<i)-1<=n;j++)
{
st1[i][j]=min(st1[i-1][j],st1[i-1][j+(1<<(i-1))]);
st2[i][j]=min(st2[i-1][j],st2[i-1][j+(1<<(i-1))]);
}
for(int i=1;i<=n;i++)
a[i]=a[i-1]+n-sa1[i]+1-he1[i];
// for(int i=1;i<=n;i++)
// cerr<<sa1[i]<<" "<<he1[i]<<" "<<a[i]<<endl;
while(q--)
{
long long x=read(),y=read();
if(max(x,y)>a[n])
{
puts("-1");
continue;
}
long long xl=ef(x),xr=xl+he1[rk1[xl]]-1+(x-a[rk1[xl]-1]),yl=ef(y),yr=yl+he1[rk1[yl]]-1+(y-a[rk1[yl]-1]),xx,yy;
// cerr<<xl<<" "<<xr<<" "<<yl<<" "<<yr<<endl;
xx=min(min(xr-xl+1,yr-yl+1),ques1(xl,yl)),yy=min(min(xr-xl+1,yr-yl+1),ques2(n-xr+1,n-yr+1));
printf("%lld\n",xx*xx+yy*yy);
}
return 0;
}