P3796 【模板】AC 自动机(加强版)
\(P3796\) 【模板】\(AC\)自动机(加强版)
作为模板\(2\),这道题的解法也是十分的经典。
我们先来分析一下题目:输入和模板\(1\)一样
对比简单版
\(P3808\) 【模板】\(AC\)自动机(简单版)
给定 \(n\) 个模式串 \(s_i\) 和一个文本串 \(t\),求有多少个不同的模式串在文本串里出现过。两个模式串不同当且仅当他们编号不同。
\(P3796\) 【模板】\(AC\) 自动机(加强版)
有 \(N\) 个由小写字母组成的模式串以及一个文本串 \(T\)。每个模式串可能会在文本串中出现多次。你需要找出哪些模式串在文本串 \(T\) 中出现的次数最多。
我们发现,加强版有两个要求:
-
求出现次数最多的次数
-
求出现次数最多的模式串
明显,我们如果 统计出每一个模式串在文本串出现的次数,那么这道题就变得十分简单了,那么问题就变成了如何统计每个模式串出现的次数。
\(AC\)自动机
首先题目统计的是 出现次数最多的字符串,所以有重复的字符串是没有关系的。(因为后面的会覆盖前面的,统计的答案也是一样的)
那么我们就将标记模式串的\(flag\)设为当前是第几个模式串。就是下面插入\(insert\)时的变化:
cnt[p]++;
变为
id[p] = x;//x表示该字符串是第x个输入的
求\(Fail\)指针没有变化,原先怎么求就怎么求。
查询
我们开一个数组\(cnt\),表示第\(i\)个字符串出现的次数。
因为是重复计算,所以不能标记为\(-1\)了。
我们每经过一个点,如果有 模式串标记号,就将\(cnt[\)模式串标记号\(]++\)。然后继续跳\(ne\),原因上面说过了。
这样我们就可以将每个模式串的出现次数统计出来。剩下的大家应该都会!
实现代码
#include <cstdio>
#include <cstring>
#include <algorithm>
#include <iostream>
using namespace std;
const int N = 1000010;
char s[150 + 10][70 + 10]; //模式串,第一维是多少个,第二维是具体的字符
char T[N]; //文本串 ,长度最大10^6
int n; //模式串数量
int cnt[N]; //每个模式串出现的次数
int tr[N][26], idx; // Trie树
int id[N]; // 节点号-mapping->模式串
void insert(char *s, int x) {
int p = 0;
for (int i = 0; s[i]; i++) {
int t = s[i] - 'a';
if (!tr[p][t]) tr[p][t] = ++idx;
p = tr[p][t];
}
id[p] = x; //记录:节点号-mapping->模式串
}
//构建AC自动机
int q[N], ne[N];
void bfs() {
int hh = 0, tt = -1;
for (int i = 0; i < 26; i++)
if (tr[0][i]) q[++tt] = tr[0][i];
while (hh <= tt) {
int p = q[hh++];
for (int i = 0; i < 26; i++) {
int t = tr[p][i];
if (!t)
tr[p][i] = tr[ne[p]][i];
else {
ne[t] = tr[ne[p]][i];
q[++tt] = t;
}
}
}
}
//查询字符串s在AC自动机中出现的次数
void query(char *s) {
int p = 0;
for (int i = 0; s[i]; i++) {
p = tr[p][s[i] - 'a'];
for (int j = p; j; j = ne[j])
if (id[j]) cnt[id[j]]++; //如果有模式串标记,更新出现次数
}
}
int main() {
//加快读入
ios::sync_with_stdio(false), cin.tie(0);
while (cin >> n && n) {
//每次清空
memset(tr, 0, sizeof tr);
memset(cnt, 0, sizeof cnt);
memset(ne, 0, sizeof ne);
memset(id, 0, sizeof id);
idx = 0;
for (int i = 1; i <= n; i++) {
cin >> s[i];
insert(s[i], i);
}
bfs();
cin >> T;
query(T);
int Max = 0;
for (int i = 1; i <= n; i++) Max = max(cnt[i], Max); //最后统计答案
printf("%d\n", Max);
//最大值可能很多个模式串匹配到,需要在获取完最大值后,再次循环输出符合最大值条件的所有模式串
for (int i = 1; i <= n; i++)
if (cnt[i] == Max) printf("%s\n", s[i]);
}
return 0;
}