AC自动机
AC自动机
它具体而言是实现字符串多模匹配的
相比KMP,它能解决找出多个字符串在文章中出现的位置(KMP只是1-1)
你可以认为它是KMP(思想)+ Trie字典树
不过和KMP一样,在查找时它加入了fail数组,fail[i]记录以i为结尾的后缀在Trie中其他字符串上的最长前缀
听起来很绕,但可以理解一下优化原理:
前面的i位置都匹配成功,说明模式串以i结尾后缀与文章这段相同
下一个位置i+1失配了,按暴力做法应从文章下一个位置,Trie树根开始匹配
但我们已经找到以i为结尾的后缀在Trie中其他字符串上的最长前缀(fail数组),于是可以从那个前缀的下一个位置和文章中开始失配位置匹配(找出的前缀自然与后缀相等,也就与文章中匹配位置相等,相当于那个字符串开头已匹配,就不用重复比较了)
这就也是KMP的优化原理!(学完了AC自动机才透彻理解KMP)
代码实现
step 1:插入(insert):和Trie一样
void insert()
{
int u = root; // 这里为方便处理root的fail,root为1
for(int i = 0; i < strlen(s); i++)
{
if(!son[u][s[i] - 'a']) son[u][s[i] - 'a'] = ++idx; // 新建节点
u = son[u][s[i] - 'a'];
}
cnt[u]++; // 末尾做标记
}
step 2:最重要的——建fail指针
根据定义可知fail指针一定往上指,那么根节点子节点的fail指针一定为根
剩下有点麻烦,但利用了BFS的思想:
当前要处理的节点为i
-
若要找的前缀(由定义知要找以i为结尾的后缀在Trie中其他字符串上的最长前缀)包含i的父节点fa,那么前缀的结尾(与i对应的点)一定是fa的fail指向的节点的子节点(fa与fail[fa]对应,那么fa是i的父亲,fail[fa]也是fail[i]的父亲)
-
若要找的前缀不包含i的父节点fa,那么前缀的结尾fa的fail指向的节点的子节点一定是根的子节点,所以fa无对应点,fail[fa]指root,那fail[i]还是fa的fail指向的节点的子节点
-
若找不到前缀,那fail[i]直接指root,但可把空节点全指root而不用特判
综上,\(fail[i] = son[fail[fa]][i]\),写出代码——
void getfail()
{
fail[root] = 0; // 初始化
queue<int> q; // 队列,实现BFS,每次扩展它的子节点
for(int i = 0; i < 26; i++)
{
son[0][i] = root; // 0的孩子为root,方便处理
if(son[root][i]) fail[son[root][i]] = root, q.push(son[root][i]); // 根节点的子节点,特殊情况
else son[root][i] = son[fail[root]][i];
} // 这里是细节,为空的节点可直接指向它的父节点的fail节点的相同子节点,就算fail[root]也指0,而0的子节点全是root
while(!q.empty())
{
int u = q.front(); q.pop();
for(int i = 0; i < 26; i++) // 枚举26个节点
if(son[u][i]) // 若此节点不为空(此时还没更新为root)
{
int p = fail[u];
fail[son[u][i]] = son[p][i]; // 如上3种情况
// 这里若fail应指root,那么它父亲的fail一定没有与它相同子节点(否则就为那个节点),指向空节点,而已经处理过空节点指root
q.push(son[u][i]); // 存在才有扩展可能,加入队列
}
else son[u][i] = son[fail[u]][i];
// 根的子节点中空节点指root,则所有空节点指上层空节点,也就指root
}
}
细节有亿点点多,错一处都不行
step 3:查询(开始匹配文章)
这里主要是根据不同题目要特殊处理
有些题只让你算出现了几个字符串,这时经过cnt(模式串为标记)就要把它赋值为-1,避免一个串出现多次而重复计算
而要求每个串出现几次,则cnt就保持原值,每扫到一次就加上它的值
此时还要存一下每个结束点对应哪些字符串(注意重复情况)
原理:
因为前面也说了以i为结尾的后缀等于以fail[i]为结尾前缀,所以若此时到i都匹配,则fail[i]前也匹配,可跳转到fail[i]继续匹配
int query()
{
int tmp, u = root; // 从root开始找
for(int i = 0; i < strlen(t); i++)
{
tmp = son[u][t[i] - 'a']; // 当前在u的与文章相同子节点
while(tmp > 1 && cnt[tmp] != -1) // tmp不为根(不为空)且cnt未被找过
{
ans += cnt[tmp]; // 可能有重复,ans要加上cnt的值
cnt[tmp] = -1; // 此只求有几个模式串出现,cnt要改为-1
tmp = fail[tmp]; // 跳fail
}
u = son[u][t[i] - 'a']; // Trie中向下一个
}
return ans;
}
upd:2022.11.23
你以为匹配就这么简单?(我当时真这么以为)
这是在每个串只看有没有出现,不用重复匹配,因此复杂度线性
但在求出现次数的题目中,需要不断跳 \(fail\),极限下(如 \(aa\cdots a\))一个点最多经过 \(O(slen)\) 次,会被卡成 \(O(slen\times\sum tlen)\)
优化:
把 \(fail\) 指针看成边,发现每个节点最多一个出度,\(fail\) 指针构成了有向无环图,姑且叫它 \(fail\) 树
那一个点会对它的所有 \(fail\) 指针直接或间接指向的点产生贡献……
不就是从入度为 \(0\) 的点开始一直传递到出度为 \(0\) 的点吗?拓扑排序!
技巧:
在 \(bfs\) 求 \(fail\) 时,已经从上往下遍历,且 \(fail\) 只会往上指
\(bfs\) 时队列中存的顺序就是一个反向的拓扑序
用数组模拟队列,反向累加即可
inline void pipei()
{
int u = root, len = s.length();
for(reg int i = 0; i < len; ++i) u = son[u][s[i] - 'a'], ++ans[u];
for(reg int i = tt; i > 0; --i) ans[fail[q[i]]] += ans[q[i]];
}
应用
1. 多个串之间出现次数
这让我们求每个模式串在整个Trie中出现次数
而我们发现,每个串的结尾若对应fail不为空,
说明有一个串以fail指向的节点为结尾的前缀和 以当前串以结尾为结束点的后继相等——
不就是这个串本身吗?这个串出现了1次
所以不断跳fail,把当前串结尾的fail所有间接或直接的节点算出个数,就是答案
这里我们把fail指针组成的树叫fail树——求指向的所有间接或直接的节点不就是求以它为根的子树大小?
那就好求了,可以常规的树上递归求解
不过这里有个技巧:之前我们求fail时已经按照树的深度从小到大对它BFS了一遍,而子树大小的状态恰好是由深度大转移为深度小
所以我们用之前的队列反着扫一遍,每次将当前节点转移至它的fail节点,就求出来了
代码如下——
#include<bits/stdc++.h>
using namespace std;
int n, son[1000010][27], cnt[1000010], root = 1, idx = 1, fail[1000010], tot, ans[210], siz[1000010], mp[210], que[1000010], hh, tt;
string s[210];
char t[1000210];
void insert(int id) // 常规插入
{
int u = root;
for(int i = 0; i < s[id].length(); i++)
{
if(!son[u][s[id][i] - 'a']) son[u][s[id][i] - 'a'] = ++idx;
u = son[u][s[id][i] - 'a'];
siz[u]++;
}
cnt[u]++, mp[id] = u; // 注意记下每个字符串结尾编号
}
void getfail() // 正常建fail
{
fail[root] = 0;
hh = tt = 0;
for(int i = 0; i < 26; i++)
{
son[0][i] = root;
if(son[root][i]) fail[son[root][i]] = root, que[++tt] = son[root][i];
else son[root][i] = son[fail[root]][i];
}
while(hh < tt) // 要存下队列,不用STL,直接用数组加减下标
{
int u = que[++hh];
for(int i = 0; i < 26; i++)
if(son[u][i])
{
int p = fail[u];
fail[son[u][i]] = son[p][i];
que[++tt] = son[u][i];
}
else son[u][i] = son[fail[u]][i];
}
}
int main()
{
scanf("%d", &n);
for(int i = 1; i <= n; i++)
{
cin >> s[i];
insert(i);
for(int j = 0; j < s[i].length(); j++) t[++tot] = s[i][j];
t[++tot] = '{';
}
getfail();
for(int i = tt; i > 0; i--) siz[fail[que[i]]] += siz[que[i]]; // 如上面所说反着扫队列
for(int i = 1; i <= n; i++) printf("%d\n", siz[mp[i]]); // 输出以字符串结尾为根的fail树子树大小
return 0;
}
2. 多串动态匹配
\(solution\):
多串,肯定建AC自动机,但每次暴力从头匹配肯定不行
模拟一下匹配跳 \(fail\) 的过程,发现删除一个串时,它前面的匹配情况不会被影响
因此直接边匹配边删除,用栈维护一下每次匹配到的自动机上的节点编号,删除时直接从上一个未删除的节点已匹配的状态开始
注意删除掉整个前缀后要从根开始
代码:
getfail(), noww = root; // AC自动机
for(reg int i = 0; i < slen; ++i)
{
noww = son[noww][s[i] - 'a'], stk1.push(noww), stk2.push_back(s[i]);
if(cnt[noww]) // 找到可删除的串了
{
for(reg int j = 1; j <= len[noww]; ++j) stk1.pop(), stk2.pop_back(); // 弹出
if(stk1.empty()) noww = root; // 可能一个前缀被删完,从头开始
else noww = stk1.top(); // 从上一个还没删的节点开始匹配
}
}
cout << stk2;
3. 完全拼接问题
\(solution\):
朴素的想法是把 \(S\) 串们建AC自动机,然后在 \(T\) 串上暴力跳 \(fail\),统计能接哪些长度的串并记录能否到达此前缀
\(95 pts\):
for(int i = 0; i < strlen(t); i++)
{
tmp = son[u][t[i] - 'a'];
while(tmp > 1)
{
if(cnt[tmp] > 0)
if(book[i + 1 - strlen(s[num[tmp]])])
{
book[i + 1] = flag = 1; // 这一位找到了能接的
break;
}
tmp = fail[tmp]; // 暴力跳fail
}
u = son[u][t[i] - 'a'];
}
for(int i = strlen(t); i >= 0; i--)
if(book[i]) return i;
但T了……
优化:
暴力跳 \(fail\) 会被卡,但 \(fail\) 树上按拓扑序 \(dp\) 是个好东西
\(|S|\) 不超过 \(20\),直接状态压缩预处理AC自动机上以每个节点为末尾能接在前面的串长,即为 \(len[i]\)
还是同上的方法,不过这次是出度为0传递给入度为0,从上往下,按队列顺序处理即可
而且根据上面,\(|T|\) 当前位置后超过 \(20\) 位就没用了,仍然状压
记 \(book\) 表示 \(T\) 当前位置后 \(20\) 位是否合法,若二进制下 \(book\) 第 \(i-1\) 位为 \(1\) 则当前位置后第 \(i\) 位能被拼接
每次转移 \(book << 1 | f[i-1]\)(向后一位,加上第 \(i-1\) 位)
当且仅当 \(book\) & \(len[noww]\) 为 \(true\),\(f[i]=1\)
(\(len[noww]\) 中有对应长度的能接上)
\(AC code\):
for(reg int i = 1; i <= tt; ++i) len[q[i]] |= len[fail[q[i]]];
for(reg int i = 1; i <= m; ++i)
{
scanf("%s", t + 1), tlen = strlen(t + 1);
noww = root, book = 0, ans = 0, f.reset(), f[0] = 1;
for(reg int j = 1; j <= tlen; ++j)
{
noww = son[noww][t[j] - 'a'];
book = ((book << 1) | f[j - 1]) & ((1ll << mx) - 1);
f[j] = (book & len[noww]) != 0;
}
for(reg int j = tlen; j > 0; --j)
if(f[j]) {ans = j; break;}
print(ans), putchar('\n');
}
4. AC 自动机上 DP
P3041 [USACO12JAN] Video Game G
想到 DP,设 \(f(i,j)\) 为当前填到长度为 \(i\),此时在 AC 自动机的 \(j\) 节点处的最大得分
这里建 AC 自动机,但注意,如果包含一个大串,这个串可能会包含很多小串
先根据 fail 树,求出如果到了某个节点,能得到的分数,即以它结尾的匹配数
然后就可以 DP 了
时间复杂度 \(O(n|S|l)\)
#include<bits/stdc++.h>
#define pb push_back
using namespace std;
int n, k, idx = 1, f[1010][310], trie[310][3], ans, cur, root = 1, q[310], hh = 1, tt, fail[310], cnt[310], book[2][310];
string s;
void insert()
{
int len = s.length(), u = root;
for(int i = 0; i < len; ++i)
{
if(!trie[u][s[i] - 'A']) trie[u][s[i] - 'A'] = ++idx;
u = trie[u][s[i] - 'A'];
}
++cnt[u];
}
void getfail()
{
for(int i = 0; i < 3; ++i)
if(trie[root][i]) fail[trie[root][i]] = root, q[++tt] = trie[root][i];
else trie[root][i] = root;
while(hh <= tt)
{
int u = q[hh++];
for(int i = 0; i < 3; ++i)
{
if(trie[u][i]) fail[trie[u][i]] = trie[fail[u]][i], q[++tt] = trie[u][i];
else trie[u][i] = trie[fail[u]][i];
}
}
for(int i = 1; i <= tt; ++i) cnt[q[i]] += cnt[fail[q[i]]];
}
int main()
{
ios::sync_with_stdio(false), cin.tie(0), cout.tie(0);
cin >> n >> k;
for(int i = 1; i <= n; ++i) cin >> s, insert();
getfail();
book[0][root] = cur = 1;
for(int i = 1; i <= k; ++i, cur ^= 1)
{
for(int j = 1; j <= idx; ++j)
if(book[cur ^ 1][j])
for(int v = 0; v < 3; ++v)
f[i][trie[j][v]] = max(f[i][trie[j][v]], f[i - 1][j] + cnt[trie[j][v]]), book[cur][trie[j][v]] = 1;
memset(book[cur ^ 1], 0, sizeof(book[cur ^ 1]));
}
for(int i = 1; i <= idx; ++i)
if(book[cur ^ 1][i]) ans = max(ans, f[k][i]);
cout << ans;
return 0;
}
5. 拼接后的串的匹配
拼接 \(s_i,s_j\),不好处理
但如果枚举中间的端点,把 \(t\) 断开
则可以看作 \(t\) 前半部分的后缀与 \(s_i\) 匹配,\(t\) 后半部分的前缀与 \(s_j\) 匹配
把贡献拆开,等于枚举端点,求 \(t\) 前半部分后缀匹配的串数及\(t\) 后半部分的前缀匹配的串数
把 \(s\) 串们建 AC 自动机
先正着做,求出前半部分,再把 \(s\) 串倒着插入另一自动机,\(t\) 倒着匹配
断点处方案数相乘即可
#include<bits/stdc++.h>
using namespace std;
typedef long long ll;
const ll N = 200010;
ll n, ans, qzh[N], m, hzh[N];
string t, s[N];
struct ACAM
{
ll trie[N][27], idx, root, cnt[N], fail[N], q[N], hh, tt;
void clear()
{
memset(trie, 0, sizeof(trie)), memset(cnt, 0, sizeof(cnt));
memset(fail, 0, sizeof(fail)), root = idx = hh = 1, tt = 0;
}
void insert(string str)
{
ll u = root, len = str.length();
for(int i = 0; i < len; ++i)
{
if(!trie[u][str[i] - 'a']) trie[u][str[i] - 'a'] = ++idx;
u = trie[u][str[i] - 'a'];
}
++cnt[u];
}
void getfail()
{
for(int i = 0; i < 26; ++i)
if(trie[root][i]) q[++tt] = trie[root][i], fail[trie[root][i]] = root;
else trie[root][i] = root;
while(hh <= tt)
{
ll u = q[hh++];
for(int i = 0; i < 26; ++i)
if(trie[u][i]) fail[trie[u][i]] = trie[fail[u]][i], q[++tt] = trie[u][i];
else trie[u][i] = trie[fail[u]][i];
}
for(int i = 1; i <= tt; ++i) cnt[q[i]] += cnt[fail[q[i]]];
}
}ac;
int main()
{
ios::sync_with_stdio(false), cin.tie(0);
cin >> t >> n; ac.clear(), m = t.length();
for(int i = 1; i <= n; ++i) cin >> s[i], ac.insert(s[i]);
ac.getfail();
for(int u = ac.root, i = 0; i < m; ++i)
{
u = ac.trie[u][t[i] - 'a'];
qzh[i] = ac.cnt[u];
}
ac.clear();
for(int i = 1; i <= n; ++i) reverse(s[i].begin(), s[i].end()), ac.insert(s[i]);
ac.getfail();
for(int u = ac.root, i = m - 1; i >= 0; --i)
{
u = ac.trie[u][t[i] - 'a'];
if(i) ans += qzh[i - 1] * ac.cnt[u];
}
printf("%lld", ans);
return 0;
}