【瞎口胡】AC 自动机
AC 自动机用来解决多模式串匹配问题。
以下便是一个经典问题:
给定 \(n\) 个模式串 \(S_1,S_2,...,S_n\) 和一个文本串 \(T\)。问有多少个模式串在文本串中出现过。
\(\sum |S_i| \leq 10^6,|T| \leq 10^6\)
考虑对模式串建出 trie。在 trie 的每个节点额外记录一个 fail,表示根到该节点表示的字符串在树中的最长后缀的节点编号。
[图没了]
对于红色箭头指向的节点,由于 \(\texttt{abc}\) 在树中的最长后缀是 \(\texttt{bc}\),所以红色节点的 fail 指向蓝色节点。特殊的,如果一个节点在树中找不到后缀,那么让它的 fail 指向根节点。
在求 fail 时可以这样写:
inline void Build_Fail(void){
std::queue <int> q;
while(!q.empty())
q.pop();
for(rr int i=0;i<26;++i){ // 第一圈的节点肯定没有 fail
if(trie[0].next[i]){
q.push(trie[0].next[i]);
}
}
while(!q.empty()){
int i=q.front();
q.pop();
for(rr int j=0;j<26;++j){
if(!trie[i].next[j]){
trie[i].next[j]=trie[trie[i].fail].next[j];//Trie 中没有这个点 特殊处理
continue;
}
trie[trie[i].next[j]].fail=trie[trie[i].fail].next[j];//类似于 KMP 的思想
q.push(trie[i].next[j]);//压入 继续更新
}
}
return;
}
而在匹配的时候,文本串直接在 trie 上走就好了。设走完 \(i\) 次后到了一个点 \(j\),那么说明以 \(T_i\) 结尾的文本串就在 \(trie_j\) 上跳 fail 就好了。
inline int Query(char *s){
int len=strlen(s);
int p=0;
int ans=0;
for(rr int i=0;i<len;++i){
int j=trie[p].next[s[i]-'a'];
while(j&&~trie[j].cnt){ // 防止重复计算 & 保证时间复杂度
ans+=trie[j].cnt;
trie[j].cnt=-1;
j=trie[j].fail;
}
p=trie[p].next[s[i]-'a'];
}
return ans;
}
优化 - 拓扑建图
对于问题的一个加强版,要求每个模式串在文本串中的出现次数。
这个时候,不能用经典问题中的 给 trie 上节点标 \(-1\) 来保证复杂度了。
因为每个节点都有一个唯一的 fail,于是将每个节点和它的 fail 连边,可以建成一个 DAG。在这个 DAG 上拓扑排序就好了。
# include <bits/stdc++.h>
# define rr
const int N=200010,INF=0x3f3f3f3f;
struct Node{
int fail;
int next[26];
}trie[N];
int endflag[N];
char S[N*10];
int cnt;
char c[N];
int id[N],du[N],v[N];
int n;
int ans[N];
inline int read(void){
int res,f=1;
char c;
while((c=getchar())<'0'||c>'9')
if(c=='-')f=-1;
res=c-48;
while((c=getchar())>='0'&&c<='9')
res=res*10+c-48;
return res*f;
}
inline void Insert(char *s,int x){
int p=0,len=strlen(s);
for(rr int i=0;i<len;++i){
if(!trie[p].next[s[i]-'a']){
trie[p].next[s[i]-'a']=++cnt;
}
p=trie[p].next[s[i]-'a'];
}
if(!endflag[p]){
endflag[p]=x;
}
id[x]=endflag[p];
return;
}
inline void GetFail(void){
std::queue <int> q=std::queue <int>();
for(rr int i=0;i<26;++i){
if(trie[0].next[i]){
q.push(trie[0].next[i]);
}
}
while(!q.empty()){
int x=q.front();
q.pop();
for(rr int i=0;i<26;++i){
if(!trie[x].next[i]){
trie[x].next[i]=trie[trie[x].fail].next[i];
continue;
}
trie[trie[x].next[i]].fail=trie[trie[x].fail].next[i];
++du[trie[trie[x].next[i]].fail];
q.push(trie[x].next[i]);
}
}
return;
}
inline void query(void){
int p=0,len=strlen(S);
for(rr int i=0;i<len;++i){
p=trie[p].next[S[i]-'a'];
++v[p]; // 跳到点 p 的次数
}
return;
}
inline void topsort(void){
std::queue <int> q=std::queue <int> ();
for(rr int i=1;i<=cnt;++i){
if(!du[i]){
q.push(i);
}
}
while(!q.empty()){
int i=q.front();
q.pop();
ans[endflag[i]]=v[i];
--du[trie[i].fail];
v[trie[i].fail]+=v[i];
if(!du[trie[i].fail]){
q.push(trie[i].fail);
}
}
}
int main(void){
n=read();
for(rr int i=1;i<=n;++i){
scanf("%s",c);
Insert(c,i);
}
scanf("%s",S);
GetFail();
query();
topsort();
for(rr int i=1;i<=n;++i){
printf("%d\n",ans[id[i]]);
}
return 0;
}