Loading

自动机入门——AC 自动机

AC 自动机

1 算法简介

AC 自动机是一个以 Trie 为基础结合 KMP 的思想建立的。在 AC 自动机中,每一个状态代表着某个模式串的前缀,而整个 DFA 的结构其实是所有模式串的 Trie 树。

而 AC 自动机可以处理这样一个问题:多模式匹配。即给你若干个模式串和一个主串,要求我们对每一个字符串和主串进行匹配。

我们肯定不能做多次 KMP,所以我们有了 AC 自动机。

2 算法讲解

2.1 状态设计

struct node{
    int ch[26],end,fail;
};

其中,ch 数组是 Trie 树的指针,end 是判断这个状态为多少串的节点,fail 指针是后缀链接,指向最长真后缀对应的状态。比如这个动图(其中的黄线是后缀链接):

其中 \(2\) 节点的指针画错了,应该是指向 \(0\)

2.2 插入

因为自身结构就是 Trie 的结构,所以 AC 自动机的插入和 Trie 树的插入是一模一样的。

代码:

	inline void insert(char *s){
        int now=0,len=strlen(s);
        for(int i=0;i<len;i++){
            int k=s[i]-'a';
            if(!p[now].ch[k]) p[now].ch[k]=++tot;
            now=p[now].ch[k];
        }
        p[now].end++;
    }

2.3 建立

在这里,我们需要建立 AC 自动机,Trie 树已经建好了,我们的目的是构建失配指针 fail 。暴力构建的话就是取其父节点,然后不断跳 fail ,直到调到一个状态,它有一条有相同字符的出边。但是这样时间复杂度不优,怎样优化?

我们先上代码:

inline void build(){
    queue<int> q;
    for(int i=0;i<26;i++) if(p[0].ch[i]) q.push(p[0].ch[i]);
    while(q.size()){
        int top=q.front();q.pop();
        for(int i=0;i<26;i++){
            if(p[top].ch[i]) p[p[top].ch[i]].fail=p[p[top].fail].ch[i],q.push(p[top].ch[i]);
            else p[top].ch[i]=p[p[top].fail].ch[i];
        }
    }
}

这里我们通过对跳 fail 的路径进行压缩,如果子节点存在,那就好说,我们把 fail 直接连过来就可以,但是如果不存在,我们就采用路径压缩,把其 fail 的子节点连过来,这样就完成了路径压缩。

放图:

2.4 查询

我们接下来分析查询函数 query ,这个函数将实现多模式匹配。我们直接放代码:

inline int query(char *t){
    int now=0,res=0,len=strlen(t);
    for(int i=0;i<len;i++){
        now=p[now].ch[t[i]-'a'];
        for(int j=now;j&&~p[j].end;j=p[j].fail) res+=p[j].end,p[j].end=-1;
    }
    return res;
}

什么意思?注意到我们对于 \(t\) 的每一个前缀,查询一下能匹配这个前缀的字符串,方法就是我们在 Trie 上走,然后对于每一个状态跳后缀链接,累加所有是结束节点的地方,注意因为可能重复调到同一个节点,所以我们要对所有的节点进行标记,这里 ~(-1)=0 ,这是一个很妙的打标记方式。

注意这 ch 是我们建出来的 Trie 图,事实上,也只有 Trie 图才能完成这个操作。如果是平常的 Trie 树,可能最终会无路可走。

如果 Trie 树走到头了,会回到 \(0\) 号节点

2.5 时间复杂度分析

设所有模式串的总长为 \(N\),主串的长度为 \(m\) ,那么建立 Trie 树的复杂度为 \(O(N)\),查询时最坏情况是所有的后缀链接都被便利,为 \(O(N)\) ,再加上主串走 Trie 树,复杂度为 \(O(m)\)

所以总复杂度为 \(O(N+m)\)

2.6 总代码

#include<bits/stdc++.h>
#define dd double
#define ld long double
#define ll long long
#define uint unsigned int
#define ull unsigned long long
#define N 1000010
#define M number
using namespace std;

const int INF=0x3f3f3f3f;

template<typename T>  inline void read(T &x) {
    x=0; int f=1;
    char c=getchar();
    for(;!isdigit(c);c=getchar()) if(c == '-') f=-f;
    for(;isdigit(c);c=getchar()) x=x*10+c-'0';
    x*=f;
}

struct node{
    int ch[26],cnt,end,fail;
};

struct AC_automaton{
    node p[N];int tot;
    inline AC_automaton(){tot=0;}
    inline void insert(char *s){
        int now=0,len=strlen(s);
        for(int i=0;i<len;i++){
            int k=s[i]-'a';
            if(!p[now].ch[k]) p[now].ch[k]=++tot;
            now=p[now].ch[k];
        }
        p[now].end++;
    }
    inline void build(){
        queue<int> q;
        for(int i=0;i<26;i++) if(p[0].ch[i]) q.push(p[0].ch[i]);
        while(q.size()){
            int top=q.front();q.pop();
            for(int i=0;i<26;i++){
                if(p[top].ch[i]) p[p[top].ch[i]].fail=p[p[top].fail].ch[i],q.push(p[top].ch[i]);
                else p[top].ch[i]=p[p[top].fail].ch[i];
            }
        }
    }
    inline int query(char *t){
        int now=0,res=0,len=strlen(t);
        for(int i=0;i<len;i++){
            now=p[now].ch[t[i]-'a'];
            for(int j=now;j&&~p[j].end;j=p[j].fail) res+=p[j].end,p[j].end=-1;
        }
        return res;
    }
};
AC_automaton ac;

int n;
char s[N],t[N];

int main(){
    read(n);
    for(int i=1;i<=n;i++) scanf("%s",s),ac.insert(s);
    scanf("%s",t);
    ac.build();printf("%d",ac.query(t));
    return 0;
}

2.7 加强版

洛谷上有两次 AC 自动机加强版,分别是:

例题 1例题 2

对于例题 1,你只要把每个字符串出现的次数记下来排序就可以了,代码如下(码风略有点不同),注意 \(end\) 的含义稍微有点变化。

#include<iostream>
#include<cstdio>
#include<cmath>
#include<algorithm>
#include<cstring>
#include<sstream>
#include<queue>
#include<map>
#include<vector>
#include<set>
#include<deque>
#include<cstdlib>
#include<ctime>
#define dd double
#define ld long double
#define ll long long
#define ull unsigned long long
#define N 160
#define M 10001000
using namespace std;
 
const int INF=0x3f3f3f3f;
 
inline int read(){
	int x=0,f=1;
	char ch=getchar();
	while(ch<'0'||ch>'9'){if(ch=='-')f=-1;ch=getchar();}
	while(ch>='0'&&ch<='9'){x=x*10+ch-'0';ch=getchar();}
	return x*f;
}

struct rode{
    int id,sum;
    char s[N];
    inline bool operator < (const rode b){
        if(sum!=b.sum) return sum>b.sum;
        return id<b.id;
    }
};
rode a[N];

char t[M];

struct ACzdj{
    int tr[N*100][26],cnt;
    int end[N*100];int fail[N*100];
    
    inline void insert(char *s,int id){
        int p=0;
        for(int i=0;s[i];i++){
            int k=s[i]-'a';
            if(!tr[p][k]) tr[p][k]=++cnt;
            p=tr[p][k];
        }
        end[p]=id;
    }
    
    inline void build(){
        queue<int> q;
        memset(fail,0,sizeof(fail));
        for(int i=0;i<26;i++) if(tr[0][i]) q.push(tr[0][i]);
        while(q.size()){
            int top=q.front();q.pop();
            for(int i=0;i<26;i++){
                if(tr[top][i]){
                    fail[tr[top][i]]=tr[fail[top]][i];
                    q.push(tr[top][i]);
                }
                else tr[top][i]=tr[fail[top]][i];
            }
        }
    }
    
    inline void clear(){
        memset(tr,0,sizeof(tr));
        memset(end,0,sizeof(end));
        memset(fail,0,sizeof(fail));
        cnt=0;
    }
    
    inline int query(char *t){
        int p=0,res=0;
        for(int i=0;t[i];i++){
            p=tr[p][t[i]-'a'];
            for(int j=p;j;j=fail[j])
                a[end[j]].sum++;
        }
        return res;
    }
};
ACzdj ac;
 
int main(){
    while(1){
        memset(a,0,sizeof(a));
        ac.clear();
        int n=read();
        if(!n) break;
        for(int i=1;i<=n;i++){
            scanf("%s",a[i].s);
            ac.insert(a[i].s,i);
            a[i].id=i;
        }
        ac.build();
        scanf("%s",t);
        ac.query(t);
        sort(a+1,a+n+1);
        printf("%d\n",a[1].sum);
        int now=1,minn=a[1].sum;
        while(a[now].sum==minn){
            printf("%s\n",a[now].s);
            now++;
        }
    }
}

对于例题 2,这个题我们在查询的时候直接这样像例题 1 做,会超时,正确的做法是我们把 fail 树建出来,然后在树上 dp 合并就可以了。

代码:

#include<iostream>
#include<cstdio>
#include<cmath>
#include<algorithm>
#include<cstring>
#include<sstream>
#include<queue>
#include<map>
#include<vector>
#include<set>
#include<deque>
#include<cstdlib>
#include<ctime>
#define dd double
#define ld long double
#define ll long long
#define ull unsigned long long
#define N 1600000
#define M 10001000
using namespace std;
 
const int INF=0x3f3f3f3f;
 
inline int read(){
	int x=0,f=1;
	char ch=getchar();
	while(ch<'0'||ch>'9'){if(ch=='-')f=-1;ch=getchar();}
	while(ch>='0'&&ch<='9'){x=x*10+ch-'0';ch=getchar();}
	return x*f;
}

struct edge{
    int to,next;
    inline void intt(int to_,int ne_){
        to=to_;next=ne_;
    }
};
edge li[N*10];
int head[N*10],tail;

inline void add(int from,int to){
    li[++tail].intt(to,head[from]);
    head[from]=tail;
}

struct rode{
    int id,sum,belong,end;
};
rode a[N];

struct ACzdj{
    int tr[N*10][26],cnt,size[N*10];
    int end[N*10];int fail[N*10];
    
    inline int insert(char *s,int id){
        int p=0;
        for(int i=0;s[i];i++){
            int k=s[i]-'a';
            if(!tr[p][k]) tr[p][k]=++cnt;
            p=tr[p][k];
        }
        a[id].end=p;
        if(!end[p]) end[p]=id;
        else return end[p];
        return id;
    }
    
    inline void build(){
        queue<int> q;
        memset(fail,0,sizeof(fail));
        for(int i=0;i<26;i++) if(tr[0][i]) q.push(tr[0][i]);
        while(q.size()){
            int top=q.front();q.pop();
            for(int i=0;i<26;i++){
                if(tr[top][i]){
                    fail[tr[top][i]]=tr[fail[top]][i];
                    q.push(tr[top][i]);
                }
                else tr[top][i]=tr[fail[top]][i];
            }
        }
    }
    
    inline void clear(){
        memset(tr,0,sizeof(tr));
        memset(end,0,sizeof(end));
        memset(fail,0,sizeof(fail));
        cnt=0;
    }
    
    inline int query(char *t){
        int p=0,res=0;
        for(int i=0;t[i];i++){
            p=tr[p][t[i]-'a'];
//            for(int j=p;j;j=fail[j])
//                a[end[j]].sum++;
            size[p]++;
        }
        return res;
    }
};
ACzdj ac;

inline void dp(int k){
    for(int x=head[k];x;x=li[x].next){
        int to=li[x].to;
        dp(to);
        ac.size[k]+=ac.size[to];
    }
}

int n;
char s[M],t[M];

int main(){
    n=read();
    for(int i=1;i<=n;i++){
        a[i].id=i;
        scanf("%s",s);
        a[i].belong=ac.insert(s,i);
    }
    ac.build();
    scanf("%s",t);
    ac.query(t);
    for(int i=1;i<=ac.cnt;i++) add(ac.fail[i],i);
    dp(0);
    for(int i=1;i<=n;i++) printf("%d\n",ac.size[a[i].end]);
//    for(int i=1;i<=n;i++) printf("%d\n",a[a[i].belong].sum);
    return 0;
}

引用

posted @ 2021-06-23 17:49  hyl天梦  阅读(354)  评论(0编辑  收藏  举报