[CF666E]Forensic Examination
题意
给你一个串\(S\)以及一个字符串数组\(T[1..m]\),\(q\)次询问,每次问\(S\)的子串\(S[p_l..p_r]\)在\(T[l..r]\)中的哪个串里的出现次数最多,并输出出现次数。
如有多解输出最靠前的那一个。
\(|S|\le5*10^5,\sum|T|\le5*10^4,q\le5*10^5\)
sol
首先肯定是对\(T[1..m]\)这个字符串数组构出广义\(SAM\)。
考虑串\(S\)在这个\(SAM\)上的匹配,是\(O(|S|)\)的。显然不能对于每组询问都暴力匹配一遍。
把所有询问离线,按照询问的\(p_r\)挂链,当原串在\(SAM\)上匹配到\(p_r\)位置的时候,这个子串\(S[p_l..p_r]\)一定会是当前状态的一个祖先(当然也有可能这个子串根本就没在\(SAM\)里面出现,这个特判掉就好了)。
找祖先显然可以倍增一下,跳到最上方的满足\(len_u\ge q_r-q_l+1\)的点\(u\)就好了。
我们现在已经找到了这个对应状态,现在需要知道这个状态分别在哪些串里面出现了多少次。对每个节点开一棵线段树,线段树以字符串数组的下表为下标\((1..m)\),值表示当前这个状态在字符串\(T[i]\)中的出现次数。需要支持查询区间最小值以及位置。
一开始线段树是只有底层状态的,对于更上层的状态如何处理?
线段树合并即可。
code
挂链的东西太多了所以直接开了int nxt[3][N],head[3][N];
。
\(0\)是后缀树上对父亲挂的链,\(1\)是询问对右端点挂的链,\(2\)是询问对倍增跳到的状态节点挂的链。
#include<cstdio>
#include<algorithm>
#include<cstring>
using namespace std;
int gi()
{
int x=0,w=1;char ch=getchar();
while ((ch<'0'||ch>'9')&&ch!='-') ch=getchar();
if (ch=='-') w=0,ch=getchar();
while (ch>='0'&&ch<='9') x=(x<<3)+(x<<1)+ch-'0',ch=getchar();
return w?x:-x;
}
const int N = 5e5+5;
struct data{
int x,y;
bool operator < (const data &b) const
{return x<b.x||x==b.x&&y>b.y;}
}ans[N];
struct seg{int ls,rs;data v;}t[N*20];
struct node{int l,r,pl,pr;}q[N];
int n,m,Q,last=1,tot=1,tr[N][26],fa[22][N],len[N],rt[N],Node;
int nxt[3][N],head[3][N];
char S[N],T[N];
void extend(int c)
{
int v=last,u=++tot;last=u;
len[u]=len[v]+1;
while (v&&!tr[v][c]) tr[v][c]=u,v=fa[0][v];
if (!v) fa[0][u]=1;
else{
int x=tr[v][c];
if (len[x]==len[v]+1) fa[0][u]=x;
else{
int y=++tot;
memcpy(tr[y],tr[x],sizeof(tr[y]));
fa[0][y]=fa[0][x];fa[0][x]=fa[0][u]=y;len[y]=len[v]+1;
while (v&&tr[v][c]==x) tr[v][c]=y,v=fa[0][v];
}
}
}
void modify(int &x,int l,int r,int p)
{
if (!x) x=++Node;
if (l==r) {t[x].v.x++;t[x].v.y=l;return;}
int mid=l+r>>1;
if (p<=mid) modify(t[x].ls,l,mid,p);
else modify(t[x].rs,mid+1,r,p);
t[x].v=max(t[t[x].ls].v,t[t[x].rs].v);
}
void merge(int &x,int y)
{
if (!x||!y) {x=x|y;return;}
if (!t[x].ls&&!t[x].rs) {t[x].v.x+=t[y].v.x;return;}//叶子节点,直接累加
merge(t[x].ls,t[y].ls);merge(t[x].rs,t[y].rs);
t[x].v=max(t[t[x].ls].v,t[t[x].rs].v);
}
data query(int x,int l,int r,int ql,int qr)
{
if (l>=ql&&r<=qr) return t[x].v;
int mid=l+r>>1;
if (qr<=mid) return query(t[x].ls,l,mid,ql,qr);
if (ql>mid) return query(t[x].rs,mid+1,r,ql,qr);
return max(query(t[x].ls,l,mid,ql,qr),query(t[x].rs,mid+1,r,ql,qr));
}
void dfs(int u)
{
for (int i=head[0][u];i;i=nxt[0][i])
dfs(i),merge(rt[u],rt[i]);
for (int i=head[2][u];i;i=nxt[2][i])
ans[i]=query(rt[u],1,m,q[i].l,q[i].r);
}
int main()
{
scanf("%s",S+1);n=strlen(S+1);
m=gi();
for (int i=1;i<=m;++i)
{
scanf("%s",T+1);int len=strlen(T+1);last=1;
for (int j=1;j<=len;++j) extend(T[j]-'a'),modify(rt[last],1,m,i);
}
Q=gi();
for (int i=1;i<=Q;++i)
{
q[i]=(node){gi(),gi(),gi(),gi()};
nxt[1][i]=head[1][q[i].pr],head[1][q[i].pr]=i;
}
for (int i=2;i<=tot;++i)
nxt[0][i]=head[0][fa[0][i]],head[0][fa[0][i]]=i;
for (int j=1;j<22;++j)
for (int i=2;i<=tot;++i)
fa[j][i]=fa[j-1][fa[j-1][i]];
for (int i=1,now=1,cnt=0;i<=n;++i)
{
while (now&&!tr[now][S[i]-'a']) now=fa[0][now],cnt=len[now];
if (!now) {now=1;cnt=0;continue;}
now=tr[now][S[i]-'a'];++cnt;
for (int j=head[1][i];j;j=nxt[1][j])
{
int u=now;if (cnt<q[j].pr-q[j].pl+1) continue;
for (int k=21;~k;--k)
if (len[fa[k][u]]>=q[j].pr-q[j].pl+1)
u=fa[k][u];
nxt[2][j]=head[2][u];head[2][u]=j;
}
}
dfs(1);
for (int i=1;i<=Q;++i)
{
if (!ans[i].x) ans[i].y=q[i].l;
printf("%d %d\n",ans[i].y,ans[i].x);
}
return 0;
}