【BZOJ5417】你的名字(NOI2018)-后缀自动机+主席树
测试地址:你的名字
做法:本题需要用到后缀自动机+主席树。
首先考虑的情况。考虑的每个前缀的贡献,我们需要找到它最短的没在中出现过,而且没在的前面部分出现过的后缀,这样包含它的所有后缀就都是合法的贡献了。显然这个后缀的长度等同于,在中出现过或者在的前面部分出现过的最长的后缀的长度加。因此我们分开考虑两种情况。对于在的前面部分出现过的最长后缀,可以通过对建后缀自动机,利用后缀树的性质,决定每个点所代表的串在哪个前缀中第一次出现,DFS一遍统计贡献即可。而对于在中出现过的最长后缀,显然就是用在的后缀自动机中匹配,匹配的长度就是最长后缀长度。这样我们就能解决68分的部分了。
再考虑拓展到任意的情况。的变化对上面的做法有影响的只有在自动机中的匹配。其实没什么大的变化,只不过我们每次需要判断走到的点合不合法。具体地,我们是需要判断,当前匹配长度为的情况下,在末尾添加一个字符后,能不能走到下一个点。
首先如果目前的匹配点完全没有字符的指针,显然就不能继续走,要根据后缀链接跳回去再进行试探。如果有字符的指针,也不一定就能走,因为那个字符串虽然出现在了中,但不一定出现在指定的区间中。这时候要判断在某个区间中存不存在当前字符串。根据后缀自动机的性质,每个点有一个集合,即该点所代表的字符串可能出现的右端点的集合,而这个集合显然就是这个点在后缀树上的子树中所包含的所有前缀节点。我们要判断的是,集合中存不存在一个,使得,而且,即。如果我们对后缀树维护一个DFS序,那么这就变成了在区间中询问存不存在一个权值在某区间内的数,这就是主席树的经典应用了,所以我们使用主席树来判断能不能往下走。
还有一点要注意,如果当前不能走,并不是直接通过后缀链接回退到上面的点,因为当前点还是可能有很多长度不一的后缀的,那么现在的问题就是在长度集合中,最大的使得权值区间在上述算法的询问中合法的,这个就是我们所回退到的串的长度,这时候再往下走就行了。这个东西可以在主席树上分治得到(需要做点处理,详见代码)。当然,如果找不到这个点,就沿着后缀链接退回上一个点。
由于回退和前进的次数都是线性的,所以以上算法的时间复杂度是,可以通过此题。
总的来说,这道题目考察了后缀自动机本身作为匹配工具的性质,又考察了作为辅助的后缀链接,也就是后缀树的性质,还考察了将高级数据结构——主席树维护DFS序的方法在后缀树上灵活运用的能力,思路自然,码量适中,给出题人点赞(我是认真的)。
(但据说已经是套路题了……)
以下是本人代码:
#include <bits/stdc++.h>
using namespace std;
const int inf=1000000000;
int n,m,tot,last,totp,first[2000010],tote;
int ch[2000010][26],pre[2000010],len[2000010],vis[2000010]={0};
int tim=0,in[1000010],out[1000010],mx[500010];
int totseg=0,rt[500010]={0},seg[10000010]={0},segch[10000010][2]={0};
int qlen,ql,qr;
char s[500010];
struct edge
{
int v,next;
}e[2000010];
void extend(int rt,int c)
{
int p,q,np,nq;
p=last;
np=++tot;
len[np]=len[p]+1;
for(int i=0;i<26;i++)
ch[np][i]=0;
while(p&&!ch[p][c]) ch[p][c]=np,p=pre[p];
if (p)
{
q=ch[p][c];
if (len[p]+1==len[q]) pre[np]=q;
else
{
nq=++tot;
for(int i=0;i<26;i++)
ch[nq][i]=ch[q][i];
len[nq]=len[p]+1;
pre[nq]=pre[q];
pre[q]=pre[np]=nq;
while(p&&ch[p][c]==q) ch[p][c]=nq,p=pre[p];
}
}
else pre[np]=rt;
last=np;
}
void insert(int a,int b)
{
e[++tote].v=b;
e[tote].next=first[a];
first[a]=tote;
}
void build(int start,int n)
{
tot=last=start;
pre[start]=len[start]=0;
len[0]=-1;
for(int i=0;i<26;i++)
ch[start][i]=0;
for(int i=1;i<=n;i++)
{
extend(start,s[i]-'a');
vis[last]=i;
}
tote=0;
for(int i=start;i<=tot;i++)
first[i]=0;
for(int i=start+1;i<=tot;i++)
insert(pre[i],i);
}
void pushup(int no)
{
seg[no]=seg[segch[no][0]]+seg[segch[no][1]];
}
void insert(int &v,int last,int l,int r,int x)
{
v=++totseg;
seg[v]=seg[last];
segch[v][0]=segch[last][0];
segch[v][1]=segch[last][1];
if (l==r) {seg[v]++;return;}
int mid=(l+r)>>1;
if (x<=mid) insert(segch[v][0],segch[last][0],l,mid,x);
else insert(segch[v][1],segch[last][1],mid+1,r,x);
pushup(v);
}
void dfs1(int v)
{
in[v]=tim+1;
if (vis[v])
{
tim++;
insert(rt[tim],rt[tim-1],1,n,vis[v]);
}
for(int i=first[v];i;i=e[i].next)
dfs1(e[i].v);
out[v]=tim;
}
void dfs2(int v)
{
if (!vis[v]) vis[v]=inf;
for(int i=first[v];i;i=e[i].next)
{
dfs2(e[i].v);
vis[v]=min(vis[v],vis[e[i].v]);
}
for(int i=first[v];i;i=e[i].next)
if (vis[v]!=vis[e[i].v])
mx[vis[e[i].v]]=len[v];
if (v==totp+1) mx[vis[v]]=0;
}
int querysum(int vr,int vl,int l,int r,int s,int t)
{
if (l>=s&&r<=t) return seg[vr]-seg[vl];
int mid=(l+r)>>1,ans=0;
if (s<=mid) ans+=querysum(segch[vr][0],segch[vl][0],l,mid,s,t);
if (t>mid) ans+=querysum(segch[vr][1],segch[vl][1],mid+1,r,s,t);
return ans;
}
int querylast(int vr,int vl,int l,int r,int s,int t)
{
if (seg[vr]-seg[vl]==0) return 0;
if (l==r) return l;
int mid=(l+r)>>1,ans=0;
if (t>mid) ans=querylast(segch[vr][1],segch[vl][1],mid+1,r,s,t);
if (ans) return ans;
if (s<=mid) ans=querylast(segch[vr][0],segch[vl][0],l,mid,s,t);
return ans;
}
void solve()
{
int now=1,d=0,l,r;
long long ans=0;
scanf("%d%d",&l,&r);
for(int i=1;i<=qlen;i++)
{
int c=s[i]-'a';
while(now)
{
if (!ch[now][c]) {now=pre[now],d=len[now];continue;}
int s=len[pre[now]]+1,t=min(d,r-l);
if (s>t||l+s>r) {now=pre[now],d=len[now];continue;}
int rta=rt[out[ch[now][c]]],rtb=rt[in[ch[now][c]]-1];
if (l+t+1>r||querysum(rta,rtb,1,n,l+t+1,r)==0)
{
int x=querylast(rta,rtb,1,n,l+s,l+t);
if (!x) {now=pre[now],d=len[now];continue;}
now=ch[now][c],d=x-l+1;break;
}
else {now=ch[now][c],d=t+1;break;}
}
if (!now) now=1,d=0;
ans+=(long long)(i-max(mx[i],d));
}
printf("%lld\n",ans);
}
int main()
{
scanf("%s",s+1);
s[0]='#';
n=strlen(s)-1;
build(1,n);
totp=tot;
dfs1(1);
scanf("%d",&m);
for(int i=1;i<=m;i++)
{
scanf("%s",s+1);
qlen=strlen(s)-1;
for(int j=totp+1;j<=totp+(qlen<<1)+1;j++)
vis[j]=0;
build(totp+1,qlen);
dfs2(totp+1);
solve();
}
return 0;
}