YbtOJ「字符串算法」第3章 后缀自动机 F. 广义后缀树 --zhengjun
题目大意
给定 \(n\) 个模板串,以及 \(m\) 个查询串,依次查询每一个查询串是多少个模板串的子串。
思路
蒟蒻不会广义 SAM,所以只能用 SA + 莫队的高复杂度笨重算法通过了。
首先把每个模板串和查询串拼接在一起(中间用不同的字符隔开),然后跑一遍 SA,对于每个查询串,记 \(i\) 为该查询串的后缀编号,\(len\) 为该查询串的长度,那么就是要找到所有的 \(j\),使得 \(LCP(i,j)\ge len\),显然这样的 \(j\) 的后缀排名 \(rk_j\) 在一段区间内,然后我们可以在拼接字符串的时候,将分隔的字符从小到大放在相邻的字符串之间,这样与当前的查询串相同的子串所对应的后缀排名就一定在后缀 \(i\) 的前面,这样,\(height_{rk_i}\) 就一定要 \(=len\)(没有 \(>len\) 的情况),然后只需要处理出每个排名为 \(i\) 的后缀,之前之后有多少个 \(j\) 满足 \(height_j\ge height_i\),这一步可以用单调栈维护。
然后,问题就转换成了在一段区间内统计颜色个数,只需要莫队或者树状数组即可。
代码
#include<bits/stdc++.h>
using namespace std;typedef long long ll;const int N=1e6+10;string a;struct ques{int l,r,id;}que[N];
int B,n,m,q,k,s[N],rk[N],old[N],sa[N],h[N],cnt[N],id[N],p[N],l[N],r[N],pos[N],st[N],stk[N],top,now,ans[N],len[N];
bool cmp(ques x,ques y){return (x.l-1)/B^(y.l-1)/B?(x.l-1)/B<(y.l-1)/B:x.r<y.r;}
void getsa(int n,int m){
for(int i=1;i<=n;i++)cnt[rk[i]=s[i]]++;for(int i=1;i<=m;i++)cnt[i]+=cnt[i-1];
for(int i=n;i>=1;i--)sa[cnt[rk[i]]--]=i;for(int i=1;i<=m;i++)cnt[i]=0;for(int len=1,k;len==1||m^n;m=k,len<<=1){
k=0;for(int i=n-len+1;i<=n;i++)p[++k]=i;for(int i=1;i<=n;i++)if(sa[i]>len)p[++k]=sa[i]-len;
for(int i=1;i<=n;i++)cnt[id[i]=rk[p[i]]]++;for(int i=1;i<=m;i++)cnt[i]+=cnt[i-1];
for(int i=n;i>=1;i--)sa[cnt[id[i]]--]=p[i];for(int i=1;i<=m;i++)cnt[i]=0;for(int i=1;i<=n;i++)old[i]=rk[i];
k=0;for(int i=1;i<=n;i++)rk[sa[i]]=old[sa[i]]==old[sa[i-1]]&&old[sa[i]+len]==old[sa[i-1]+len]?k:++k;
}for(int i=1,k=0;i<=n;i++){if(k)k--;while(max(i,sa[rk[i]-1])+k<=n&&s[i+k]==s[sa[rk[i]-1]+k])k++;h[rk[i]]=k;}
}
int main(){
scanf("%d%d",&n,&q);k=128;for(int i=1;i<=n;i++){cin>>a;st[i]=m+1;for(char x:a)s[++m]=x,pos[m]=i;s[++m]=++k;pos[m]=0;}
for(int i=1;i<=q;i++){cin>>a;len[i]=a.length();st[i+n]=m+1;for(char x:a)s[++m]=x,pos[m]=0;s[++m]=++k;pos[m]=0;}
getsa(m,k);stk[top=0]=1;for(int i=2;i<=m;i++){while(top&&h[stk[top]]>=h[i])top--;l[i]=stk[top];stk[++top]=i;}
stk[top=0]=m+1;for(int i=m;i>=2;i--){while(top&&h[stk[top]]>=h[i])top--;r[i]=stk[top];stk[++top]=i;}B=sqrt(n);
for(int i=1,cur;cur=rk[st[i+n]],i<=q;i++)que[i]={l[cur],r[cur]-1,i};sort(que+1,que+1+q,cmp);for(int i=1,l=1,r=0;i<=q;i++){
while(l>que[i].l)now+=!cnt[pos[sa[--l]]]++;while(r<que[i].r)now+=!cnt[pos[sa[++r]]]++;
while(l<que[i].l)now-=!--cnt[pos[sa[l++]]];while(r>que[i].r)now-=!--cnt[pos[sa[r--]]];ans[que[i].id]=now-(cnt[0]>0);
}for(int i=1,cur;cur=rk[st[i+n]],i<=q;i++)if(h[cur]>=len[i])printf("%d\n",ans[i]);else puts("0");return 0;
}