Luogu P3808 【模板】AC自动机(简单版)
AC自动机(Aho-Corasick automaton)是一种优化的多模式串匹配的算法,它像是trie树和KMP的结合体。
这个算法分为三部分:建立trie树,求fail,匹配。
假设有cod,cos,ost,op几个单词。(这图画了好久好久好久)
在普通trie树的基础上,像kmp一样处理出模式串的失配函数。(没有画出的,fail指向根0)
建立trie树
和普通的trie相同。
void insert(char *s) { int len = strlen(s); int u = 0; for(int i = 0; i < len; i++) { int v = s[i]-'a'; if(!trie[u].son[v]) trie[u].son[v] = ++num; u = trie[u].son[v]; } trie[u].fin++; }
求fail
用bfs的方法,每次遍历'a'-'z',并将存在的节点压入队列。
- trie上第一层fail = 0
- x的son[i]的fail = x的fail的son[i]
- 如果现在x不存在son[i],那么x的son[i] = x的fail的son[i]
根据上述规则,得到以下代码
void getf() { queue <int> q; for(int i = 0; i < 26; i++) { if(trie[0].son[i]) { trie[trie[0].son[i]].fail = 0; q.push(trie[0].son[i]); } } while(!q.empty()) { int u = q.front(); q.pop(); for(int i = 0; i < 26; i++) { if(trie[u].son[i]) { trie[trie[u].son[i]].fail = trie[trie[u].fail].son[i]; q.push(trie[u].son[i]); } else trie[u].son[i] = trie[trie[u].fail].son[i]; } } }
匹配
本题中,要记录文本串中出现的模式串的数量。
和kmp类似。每走到一个节点,将以这个节点为结尾的数量计入答案,并将模式串不断跳到fail。
因为可能会形成环,需要记录是否访问过当前节点,防止重复。
int query(char *s) { int len = strlen(s); int u = 0; int ans = 0; for(int i = 0; i < len; i++) { int v = s[i]-'a'; u = trie[u].son[v]; for(int j = u; j && !trie[j].vis; j = trie[j].fail) { ans += trie[j].fin; trie[j].vis = true; } } return ans; }
还有什么暂时想不到了
完整代码如下
#include<cstdio> #include<iostream> #include<cmath> #include<cstring> #define MogeKo qwq #include<queue> using namespace std; const int maxn = 1e6+10; int n,num; char s[maxn]; struct node { int son[26]; int fin,fail; bool vis; } trie[maxn]; void insert(char *s) { int len = strlen(s); int u = 0; for(int i = 0; i < len; i++) { int v = s[i]-'a'; if(!trie[u].son[v]) trie[u].son[v] = ++num; u = trie[u].son[v]; } trie[u].fin++; } void getf() { queue <int> q; for(int i = 0; i < 26; i++) { if(trie[0].son[i]) { trie[trie[0].son[i]].fail = 0; q.push(trie[0].son[i]); } } while(!q.empty()) { int u = q.front(); q.pop(); for(int i = 0; i < 26; i++) { if(trie[u].son[i]) { trie[trie[u].son[i]].fail = trie[trie[u].fail].son[i]; q.push(trie[u].son[i]); } else trie[u].son[i] = trie[trie[u].fail].son[i]; } } } int query(char *s) { int len = strlen(s); int u = 0; int ans = 0; for(int i = 0; i < len; i++) { int v = s[i]-'a'; u = trie[u].son[v]; for(int j = u; j && !trie[j].vis; j = trie[j].fail) { ans += trie[j].fin; trie[j].vis = true; } } return ans; } int main() { scanf("%d",&n); for(int i = 1; i <= n; i++) { scanf("%s",s); insert(s); } getf(); scanf("%s",s); printf("%d",query(s)); return 0; }