P4770 [NOI2018]你的名字(SAM+线段树合并)
题目链接
https://www.luogu.com.cn/problem/P4770
题意
给出一个字符串\(S\),\(Q\)组询问,每次询问字符串\(T\)中有多少个本质不同的子串在\(S[l..r]\)中没有出现。
思路
问题转化为有几个\(T\)的本质不同的子串满足这个子串在\(S\)的给定区间当中出现了,然后我们求出这个东西之后拿\(T\)的所有本质不同子串去减就能得到答案了
先考虑一下\(l=1,r=|S|\)的情况
设\(lim[i]\)表示字符串\(T[1..i]\)能在\(S\)中匹配到的最长后缀(即\(T[i−lim[i]+1,i]\)是\(S\)的子串且\(lim[i]\)最大)
这个\(lim[i]\)可以不断地在\(S\)的后缀自动机上跳来求出。当无法向下匹配时,一直跳\(parent\)树直到可以匹配为止。
再对\(T\)建后缀自动机
那么答案就是:\(ans=\sum_{i=2}^{cnt}max(0,len[i]−max(len[fa[i]],lim[pos[i]]))\)
其中\(pos[i]\)表示该集合字符串出现的位置(随便一个都行,这里取最大的位置)
然后\(l\)和\(r\)任意的情况用线段树维护\(S\)后缀自动机的\(endpos\)集合就行了
在求\(lim[i]\)时要注意一点,匹配失败不能直接跳\(parent\),节点集合中的字符串要逐个匹配。
#include<bits/stdc++.h>
using namespace std;
typedef long long LL;
const int maxx = 2*1e6+10;
struct SAM
{
int last,tot,fa[maxx],ch[maxx][26],len[maxx],pos[maxx];
void init()
{
last=tot=1;
memset(ch[1],0,sizeof(ch[1]));
}
void add(int x)
{
int pre=last,now=last=++tot;
memset(ch[now],0,sizeof(ch[now]));
len[now]=len[pre]+1;
pos[now]=len[now];
for(;pre&&!ch[pre][x];pre=fa[pre])ch[pre][x]=now;
if(!pre)fa[now]=1;
else
{
int q=ch[pre][x];
if(len[q]==len[pre]+1)fa[now]=q;
else
{
int nows=++tot;
memset(ch[nows],0,sizeof(ch[nows]));
len[nows]=len[pre]+1;
pos[nows]=pos[now];
memcpy(ch[nows],ch[q],sizeof(ch[q]));
fa[nows]=fa[q];fa[q]=fa[now]=nows;
for(;pre&&ch[pre][x]==q;pre=fa[pre])ch[pre][x]=nows;
}
}
}
}s1,s2;
int sum[50*maxx],ls[50*maxx],rs[50*maxx],rt[maxx],cnt;
int head[maxx],to[maxx],ne[maxx],num;
int n;
void update(int &u,int l,int r,int x)
{
if(!u)u=++cnt;
if(l==r)
{
sum[u]=1;
return;
}
int mid=(l+r)/2;
if(x<=mid)update(ls[u],l,mid,x);
else update(rs[u],mid+1,r,x);
sum[u]=sum[ls[u]]+sum[rs[u]];
}
int query(int u,int l,int r,int p,int q)
{
if(p>q)return 0;
if(!u)return 0;
if(p<=l&&r<=q)return sum[u];
int mid=(l+r)/2;
int ans=0;
if(p<=mid)ans+=query(ls[u],l,mid,p,q);
if(q>mid)ans+=query(rs[u],mid+1,r,p,q);
return ans;
}
int merge(int a,int b,int l,int r)
{
if(!a)return b;
if(!b)return a;
int u=++cnt;
if(l==r)
{
sum[u]=sum[a]|sum[b];
return u;
}
int mid=(l+r)/2;
ls[u]=merge(ls[a],ls[b],l,mid);
rs[u]=merge(rs[a],rs[b],mid+1,r);
sum[u]=sum[ls[u]]+sum[rs[u]];
return u;
}
void addm(int u,int v)
{
to[++num]=v,ne[num]=head[u],head[u]=num;
}
void dfs(int u)
{
for(int i=head[u];i;i=ne[i])
{
dfs(to[i]);
rt[u]=merge(rt[u],rt[to[i]],1,n);
}
}
char s[maxx],t[maxx];
int lim[maxx];
int main()
{
scanf("%s",s+1);
n=strlen(s+1);
s1.init();
for(int i=1;i<=n;i++)
{
update(rt[s1.tot+1],1,n,i);
s1.add(s[i]-'a');
}
for(int i=2;i<=s1.tot;i++)addm(s1.fa[i],i);
dfs(1);
int q,l,r;
scanf("%d",&q);
while(q--)
{
scanf("%s%d%d",t+1,&l,&r);
int m=strlen(t+1),now=1,len=0;
for(int i=1;i<=m;i++)
{
int x=t[i]-'a';
while(now&&!(s1.ch[now][x]&&query(rt[s1.ch[now][x]],1,n,l+len,r)))
{
if(len>s1.len[s1.fa[now]])len--;
else now=s1.fa[now];
}
if(!now)now=1,len=0;
else len++,now=s1.ch[now][x];
lim[i]=len;
}
s2.init();
for(int i=1;i<=m;i++)s2.add(t[i]-'a');
LL ans=0;
for(int i=2;i<=s2.tot;i++)
ans+=max(0,s2.len[i]-max(s2.len[s2.fa[i]],lim[s2.pos[i]]));
printf("%lld\n",ans);
}
return 0;
}