P5357【模板】AC自动机(二次加强版)

\(P5357\) 【模板】\(AC\) 自动机(二次加强版)

一、题目描述

给你一个文本串 \(S\)\(n\) 个模式串 \(T_{1 \sim n}\),请你分别求出每个模式串 \(T_i\)​在 \(S\) 中出现的次数。

二、对加强版进行改造

因为是什么二次加强版,所以大家先去做一下 加强版 吧,做法差不多。

好了,看到这里大家都一定做过加强版了吧,那么这道题的做法也是差不多的: 我们这一次不需要求出现最多的字符串啦,直接将\(cnt\)数组输出就好了!(应该都知道\(cnt\)数组是什么吧,就是统计每个模式串在文本串出现多少次的数组

重复的单词有没有影响啊!有啊!对于加强版这一次重复的单词就会有影响啦,怎么办?

这道题有相同字符串要统计,设当前字符串是第\(x\)个,我们用\(family[x]\)数组存当前字符串在\(Trie\)中的那个位置输入模式串序号,最后把\(cnt[family[i]]\)输出就\(OK\)了。另外\(id\)只在第一次赋值时变化,其他都不变。

本题思路很简单,如果你做过加强版的话。
这个思路很好搞,就是简单统计出现次数,然后输出。
不过如果你直接交会发现\(TLE\)
我当时就是非常高兴的把加强版的代码改了改交了上去:

#include <cstdio>
#include <cstring>
#include <algorithm>
#include <iostream>
using namespace std;
const int N = 2 * 1e5 + 10;
const int M = 2 * 1e6 + 10;

char s[N], T[M];
int n;
int tr[N][26], idx, ne[N];
int id[N];
int cnt[N];
int family[N];

void insert(char *s, int x) {
    int p = 0;
    for (int i = 0; s[i]; i++) {
        int t = s[i] - 'a';
        if (!tr[p][t]) tr[p][t] = ++idx;
        p = tr[p][t];
    }

    if (!id[p]) id[p] = x; // id[p]记录的是首个入驻的模式串号x
    family[x] = id[p];     // 将所有最终位置是p号节点,也就是重复模式串,都划归到family[x]这个首次入驻模式串x为同一家族
}

int q[N], hh, tt = -1;
void bfs() {
    for (int i = 0; i < 26; i++)
        if (tr[0][i]) q[++tt] = tr[0][i];

    while (hh <= tt) {
        int p = q[hh++];
        for (int i = 0; i < 26; i++) {
            int t = tr[p][i];
            if (!t)
                tr[p][i] = tr[ne[p]][i];
            else {
                ne[t] = tr[ne[p]][i];
                q[++tt] = t;
            }
        }
    }
}

void query(char *s) {
    int p = 0;
    for (int i = 0; s[i]; i++) {
        p = tr[p][s[i] - 'a'];
        for (int j = p; j; j = ne[j])
            if (id[j]) cnt[id[j]]++;
    }
}

int main() {
    //加快读入
    ios::sync_with_stdio(false), cin.tie(0);
    cin >> n;
    for (int i = 1; i <= n; i++) {
        cin >> s;
        insert(s, i);
    }
    bfs();

    cin >> T;
    query(T);

    for (int i = 1; i <= n; i++) printf("%d\n", cnt[family[i]]);
    return 0;
}

三、持续优化

那么我们现在得到的算法是直接跑自动机跳\(ne\)树,然后每跳到一个标记点就计数加一。

考虑对其进行优化:

当匹配到单词的时,我们会不断地去跳\(ne\),同一个点会可能被跳多次。
那么我们可以想一下把\(ne\)指针单独建出来他是个什么样子的。

它就是一棵树。

那么对于这一棵树,我们每次匹配时都会去更新它的父亲节点(\(ne\)树),那么对于树上的一条链,每一个子节点也有父子关系,他们会有共同的祖先。对于一对被遍历过的父子节点,它们的共同祖先显然会被父亲跳一次再被儿子跳一次,如果能够减少跳的次数,同时不丢失贡献,那么我们就能降低复杂度,从而完成本题。

那么我们思考一下,我们在跑自动机时如果先不跳\(ne\),而是单纯的跑\(trie\)树,是比连跑带跳(\(ne\))复杂度小不少的。那么跑完\(trie\)树,我们得到的是什么?

我们得到的是文本串在自动机上跑过的痕迹(脚印),我们也就得到了每个节点(不跳\(ne\))被遍历的次数,在这些节点中,我们可以拿出来再更新\(ne\)

这时我们应该想一下既然都拿出来了,有没有什么方法能优化更新?
这样我们就需要思考\(ne\)树的性质
我们思考一下\(ne\)指针的建立:当前节点的 / 父亲节点的 / \(ne\)指针指向节点的 / 子节点 。
\(ne\)指针在树上跳时是一定向上跳的,最下面的节点会更新上面的父亲节点。
那么一个点被遍历的次数就是:\(trie\)树上遍历次数 + \(ne\)树上子节点被遍历次数

而子节点被遍历次数又取决于其\(trie\)树上遍历次数和自身子节点的个数。
那么最下面的点是不需要被其他点通过\(ne\)更新的

如果最下面的点更新过自己的父亲节点,那么它的父亲节点也就是次深的点就成了刚才最下面的点的状态。
而且一个节点只会被更新一次。

于是我们就得到了\(trie\)树的更新方法, 通过拓扑序 更新\(ne\)树,从底往上不断累加,最后输出结果。

拓扑序优化递推版本

#include <bits/stdc++.h>
using namespace std;
const int N = 200010;
const int M = 2000010; //文本串长度

int f[N];
char s[N];
char T[M];
int tr[N][26], idx, ne[N], id[N];

void insert(char *s, int x) {
    int p = 0;
    for (int i = 0; s[i]; i++) {
        int t = s[i] - 'a';
        if (!tr[p][t]) tr[p][t] = ++idx;
        p = tr[p][t];
    }
    id[x] = p;
}

//构建AC自动机
int q[N], hh, tt = -1;
void bfs() {
    for (int i = 0; i < 26; i++)
        if (tr[0][i]) q[++tt] = tr[0][i];

    while (hh <= tt) {
        int t = q[hh++];
        for (int i = 0; i < 26; ++i) {
            if (tr[t][i]) {
                ne[tr[t][i]] = tr[ne[t]][i];
                q[++tt] = tr[t][i];
            } else
                tr[t][i] = tr[ne[t]][i];
        }
    }
}

void query(char *s) {
    int p = 0;
    for (int i = 0; s[i]; i++) { // 枚举文本串每一个字符
        int t = s[i] - 'a';      // 字符映射的数字t,可以理解为边
        p = tr[p][t];            // 走进去,到达的位置替换p
        f[p]++;                  // 标识此位置有人走过,记录走的次数
    }
}

int main() {
    //加快读入
    ios::sync_with_stdio(false), cin.tie(0);
    int n;
    cin >> n;
    for (int i = 1; i <= n; i++) {
        cin >> s;
        insert(s, i);
    }
    //构建AC自动机
    bfs();
    //文本串
    cin >> T;
    query(T);
    for (int i = idx; i; i--) f[ne[q[i]]] += f[q[i]]; //一路向上,计算叠加值
    //输出
    for (int i = 1; i <= n; i++) printf("%d\n", f[id[i]]);
    return 0;
}
posted @ 2022-05-13 14:26  糖豆爸爸  阅读(72)  评论(0编辑  收藏  举报
Live2D