【BZOJ3413】匹配(后缀自动机,线段树合并)
【BZOJ3413】匹配(后缀自动机,线段树合并)
题面
题解
很好的一道题目。
做一个转化,匹配的次数显然就是在可以匹配的区间中,每个前缀的出现次数之和。
首先是空前缀的出现次数,意味着你会去匹配第一个字符。
然后是第一个字符的出现次数,意味着你回去匹配前两个字符。
如此下去就是最后的答案。
那么构建\(SAM\)后线段树合并维护好每个点的\(endpos\)。
然后对于询问串在\(SAM\)上跑一遍就好了。
注意下每个\(endpos\)的可行范围到底是哪里,以及最终整个询问串是不需要计算到答案里的。
#include<iostream>
#include<cstdio>
#include<cstring>
using namespace std;
#define MAX 100100
inline int read()
{
int x=0;bool t=false;char ch=getchar();
while((ch<'0'||ch>'9')&&ch!='-')ch=getchar();
if(ch=='-')t=true,ch=getchar();
while(ch<='9'&&ch>='0')x=x*10+ch-48,ch=getchar();
return t?-x:x;
}
int n,m;
char ch[MAX];
struct SegNode{int ls,rs,v;}T[MAX<<6];
int TOT,rt[MAX<<1],lst[MAX<<1];
void Modify(int &x,int l,int r,int p)
{
if(!x)x=++TOT;T[x].v+=1;if(l==r)return;
int mid=(l+r)>>1;
if(p<=mid)Modify(T[x].ls,l,mid,p);
else Modify(T[x].rs,mid+1,r,p);
}
int Merge(int x,int y)
{
if(!x||!y)return x|y;
int z=++TOT;
T[z].ls=Merge(T[x].ls,T[y].ls);
T[z].rs=Merge(T[x].rs,T[y].rs);
T[z].v=T[T[z].ls].v+T[T[z].rs].v;
return z;
}
int Query(int x,int l,int r,int L,int R)
{
if(L>R||!x)return 0;
if(L<=l&&r<=R)return T[x].v;
int mid=(l+r)>>1,ret=0;
if(L<=mid)ret+=Query(T[x].ls,l,mid,L,R);
if(R>mid)ret+=Query(T[x].rs,mid+1,r,L,R);
return ret;
}
struct Node
{
int son[10];
int len,ff;
}t[MAX<<1];
int last=1,tot=1;
void extend(int c,int id)
{
int p=last,np=++tot;last=tot;
t[np].len=t[p].len+1;
while(p&&!t[p].son[c])t[p].son[c]=np,p=t[p].ff;
if(!p)t[np].ff=1;
else
{
int q=t[p].son[c];
if(t[q].len==t[p].len+1)t[np].ff=q;
else
{
int nq=++tot;
t[nq]=t[q];t[nq].len=t[p].len+1;
t[q].ff=t[np].ff=nq;
while(p&&t[p].son[c]==q)t[p].son[c]=nq,p=t[p].ff;
}
}
Modify(rt[np],1,n,id);lst[np]=id;
}
int p[MAX<<1],a[MAX<<1];
int check(char *ch)
{
int now=1,l=strlen(ch+1);
for(int i=1;i<=l;++i)
{
int c=ch[i]-48;
if(t[now].son[c])now=t[now].son[c];
else return -1;
}
return lst[now];
}
int main()
{
n=read();scanf("%s",ch+1);memset(lst,63,sizeof(lst));
for(int i=1;i<=n;++i)extend(ch[i]-48,i);
for(int i=1;i<=tot;++i)a[t[i].len]++;
for(int i=1;i<=n;++i)a[i]+=a[i-1];
for(int i=1;i<=tot;++i)p[a[t[i].len]--]=i;
for(int i=tot;i;--i)
if(t[p[i]].ff)
{
rt[t[p[i]].ff]=Merge(rt[t[p[i]].ff],rt[p[i]]);
lst[t[p[i]].ff]=min(lst[t[p[i]].ff],lst[p[i]]);
}
m=read();
while(m--)
{
scanf("%s",ch+1);
int l=strlen(ch+1),h=check(ch),ans;
if(h==-1)ans=n;
else ans=h+1-l;
for(int i=1,now=1;i<l;++i)
{
int c=ch[i]-48;
if(t[now].son[c])now=t[now].son[c];
else break;
if(h==-1)ans+=Query(rt[now],1,n,1,n);
else ans+=Query(rt[now],1,n,1,h-l+i);
}
printf("%d\n",ans);
}
return 0;
}