AC自动机

AC自动机(Aho-Corasick自动机)

部分资料来自

https://www.cnblogs.com/cjyyb/p/7196308.html

http://www.cppblog.com/menjitianya/archive/2014/07/10/207604.html

https://www.luogu.com.cn/problemnew/solution/P5357

问题模式

给定多个模式串和一个目标串

问有关模式串匹配问题

如果没有AC自动机,你可能需要对n个模板串分别求一趟KMP,但是复杂度过高,而AC自动机可以一次匹配,效率更优秀

俗话称 AC自动机=KMP+Trie

模型的应用

很多网站,游戏都有敏感词过滤功能,其底层实现也无非就是ac自动机

解决步骤

1.将各模式串构建起Trie树

此部分可看Trie相关芝士~

2.构建失配指针(核心部分)

他的遍历方式是利用BFS

至于为什么使用BFS,下文有提及,他与构建失配指针以及更新访问域的先后有关

注意,失配指针很多文章没有说清楚真正的含义,它实际上有两部分

Part1.强制失配指针(伪失配处理,failptr)

先看上图,理解一下它的含义,它与KMP里面的next数组作用相似:

对于节点x而言,它的失配指针指向的节点标号为u,则有

自上而下形成的字符串中,x对应字符串的最长后缀等于u对应字符串的最长前缀

(如图红线,最长后缀是she中的he,也是her的最长前缀he)

它的作用在于在匹配一段相对较长的模式串时,可能其后缀蕴含了一段(或多段)其他的模式串,当前者匹配成功,后者同样也能被匹配出来

如offset,set,你现在正在遍历offset,它后缀中含有set,你正在匹配offset的时候不可能说我停下来去check一下它后缀情况如何。取而代之的是,我们使用failptr去假装它失配了,人为强制地让这个指针去跳转检索一下,看看当前后缀当中是否能匹配到其他模式串

构建failptr对应代码:

trie[trie[curpos].vis[i]].failptr=trie[trie[curpos].failptr].vis[i];

当前踩在curpos这个父节点,父节点curpos失配指针指向节点q,我们现在访问curpos的子节点vis,如果它是存在的,那么vis节点的失配指针指向是q节点对应字符的vis节点

这一段代码实现了三个功能,对应q节点的三种不同情况:

1.若非空,那么我们就成功使得原来最长后缀++

2.若此节点为空,且指向0,就直接指向了root

3.若此节点为空,且访问域被更新过,那么就是把访问域的地址给了过去(后面讲访问域)

这里啰嗦一句,这里体现了为什么我们节点之间是通过标号来构建关系(每新构建一个新的节点,标号加1)

其优势在于,我们所有指向空的表示为0,Trie树的虚根也记为0,那么在构建失配指针的时候会让代码显得很简洁,不会像指针型这一个null,那一个null


Part2.访问域的更新(后继状态的更新,遍历跳转指针的更新,vis[26])

我们再来理解一下访问域的更新

它其实表示的是节点后继状态的更新。Trie树的每个节点均有26个访问域,对应26个字母。

它既可以储存指向代表含真正字符的子节点地址,又可能是空

倘若它是空的,那就白白浪费了一块空间。类似线索二叉树,AC自动机将这块空的区域利用了起来,加快跳转效率

如上图,我们构建了含有模式串abc,bc,s的trie树,菱形s是c对应的vis访问域,我们现在匹配abcs

当我们匹配abcs时,abc匹完了,现在多个s,接下来该跳转到哪里嘞?

显然c访问域为空,就会匹配指针无所适从,不知道下一个s何去何从

解决问题的代码,当vis指向空:

trie[curpos].vis[i]=trie[trie[curpos].failptr].vis[i];

我们建树的时候将这个空的区域指向真实存在的元素,并且将这个对应的地址传递下去。

这时BFS的优势就体现了出来,我给bc串更新访问域vis['s']时,让他指向了s,我们再给abc更新的时候,我们赋给它bc串的vis['s'],也就是s的地址,换句话说,我们就成功的把s的地址传递下移了

这样一来,当我们的abc完成了匹配,要跳转匹配s的时候,一步到位

真正模拟上述过程很麻烦,他们互相融合交错的,分不了孰先孰后,多画图模拟几次


上述对应的代码

值得强调的是,虚根的直接子节点一定要额外优先处理(建立failptr,入队),不然会有瑕疵

void built_fail()
{	
	queue<int>curline;
	for(int i=0;i<26;i++)
	{
	    if(trie[0].vis[i])trie[trie[0].vis[i]].failptr=0,curline.push(trie[0].vis[i]);
	}
	while(!curline.empty())
	{
	    int curpos=curline.front();
	    curline.pop();
	    for(int i=0;i<26;i++)
	    {
	        if(trie[curpos].vis[i])
	        {
	            trie[trie[curpos].vis[i]].failptr=trie[trie	[curpos].failptr].vis[i];
                curline.push(trie[curpos].vis[i]);
            }
            else
	            trie[curpos].vis[i]=trie[trie[curpos].failptr].vis[i];
        }
    }

}

3.遍历目标串

完成上一步失配指针的配置后,匹配就变得简单了

每次沿着Trie树匹配,匹配到当前位置失配时,直接跳转到访问域所指向的位置继续进行匹配,而每次匹配过程中还要进行一次伪失配的处理,在这个过程中进行统计。

核心代码

int checkAC(string s)
{
    int ans=0,curpos=0;
    for(int i=0;i<s.size();i++)
    {
        curpos=trie[curpos].vis[s[i]-'a'];
        for(int j=curpos;j&&trie[j].sum!=-1;j=trie[j].failptr)
        {
            ans+=trie[j].sum;
            trie[j].sum=-1;
        }
    }
    return ans;
}

优化

我们上面AC自动机在进行多模匹配时是暴力跳转failptr,但这样做复杂度还是有问题

在类似于aaaaa……aaaaa这样的串中,复杂度会退化成O(模式串长度·目标串长度)为什么?因为对于每一次跳转failptr我们都只使深度减1,那样深度(深度最深是模式串长度)是多少,每一次跳的时间复杂度就是多少。那么还要乘上文本串长度,就几乎是O(模式串长度·文本串长度)的了

再举个栗子

我们匹配ABC的时候(1234),强制失配跳转failptr的时候要先经过BC(57),可是57上并没有结束点,我们要的是9上的c,所以跳转效率大打折扣

优化思路一,我自己想的(拉垮的很):

这里运用路径压缩的思想。我们强制跳转failptr的过程采用递归的方式,通过记录含有结束单词位置,不断回溯更新failptr的地址

代码如下:

void dfs(int curpos)
{
	if(!curpos){existpos=0;return;}
	dfs(trie[curpos].failptr);
	trie[curpos].failptr=existpos;
	if(trie[curpos].stringpos)total[trie[curpos].stringpos]++,existpos=curpos;
}

在后面模板题洛谷P5356中,经过我反复各类卡常优化,我唯一T的点还是T了,1.15s->1.09s

(假算法++)

优化思路二,大佬题解:

它的核心是,观察到了failptr与节点之间构成了DAG

它优先进行目标串的匹配,放弃每步暴力跳转failptr,优先给各个节点统计上遍历过的次数sum

接下来再进行在DAG上进行拓扑排序,累加统计即可

数组in记录入度

in[trie[trie[curpos].vis[i]].failptr]++;//当vis存在,并且被更新了failptr

AC自动机匹配时,先遍历目标串,统计各点的sum:

int ans=0,curpos=0;
for(register int i=0;i<s.size();i++)
{
    curpos=trie[curpos].vis[s[i]-'a'];
    trie[curpos].sum++;
}

DAG上拓扑累加:

queue<int>curline;
for(int i=1;i<=cnt;i++)if(!in[i])curline.push(i);//初始化拓扑队列
while(!curline.empty())
{
    int x=curline.front();curline.pop();
    total[trie[x].stringpos]=trie[x].sum;//记录对应模式串出现的次数
    int y=trie[x].failptr;in[y]--;
    trie[y].sum+=trie[x].sum;//DAG上状态的转移
    if(in[y]==0)curline.push(y);//入度为0,入队
}

这样一来复杂度就成了O(max(模式串长度,文本串长度))

详细代码见下面P5357

优化效果果然大幅度提升


三套模版(第三套效率比较好)

1.洛谷模版(弱化版)P3808

查询有几个模式串出现在目标串中

#include<iostream>
#include<cstring>
#include<cstdio>
#include<string>
#include<queue>
using namespace std;
#define INF 1e10+5
#define maxn 1000005
#define minn -105
#define ll long long int
#define ull unsigned long long int
#define uint unsigned int
struct trienode
{
    int failptr;
    int vis[26];
    int sum;
    trienode(){memset(vis,0,sizeof(vis));failptr=0;sum=0;}
};
trienode trie[maxn];
int cnt=0;
void built(string s)//build trie
{
    int curpos=0;
    for(int i=0;i<s.size();i++)
    {
        if(!trie[curpos].vis[s[i]-'a'])trie[curpos].vis[s[i]-'a']=++cnt;
        curpos=trie[curpos].vis[s[i]-'a'];
    }
    trie[curpos].sum++;
}
void built_fail()//build fail_ptr
{	
	queue<int>curline;
	for(int i=0;i<26;i++)
	{
	    if(trie[0].vis[i])trie[trie[0].vis[i]].failptr=0,curline.push(trie[0].vis[i]);
	}
	while(!curline.empty())
	{
	    int curpos=curline.front();
	    curline.pop();
	    for(int i=0;i<26;i++)
	    {
	        if(trie[curpos].vis[i])
	        {
	            trie[trie[curpos].vis[i]].failptr=trie[trie	[curpos].failptr].vis[i];
                curline.push(trie[curpos].vis[i]);
            }
            else
	            {
	            trie[curpos].vis[i]=trie[trie[curpos].failptr].vis[i];
	            }
        }
    }

}
int checkAC(string s)
{
    int ans=0,curpos=0;
    for(int i=0;i<s.size();i++)
    {
        curpos=trie[curpos].vis[s[i]-'a'];
        for(int j=curpos;j&&trie[j].sum!=-1;j=trie[j].failptr)
        {
            ans+=trie[j].sum;
            trie[j].sum=-1;
        }
    }
    return ans;
}
int main()
{
    int _t;
    string s;
    cin>>_t;
    while(_t--)
    {
        cin>>s;
        built(s);
    }
    trie[0].failptr=0;
    built_fail();
    cin>>s;
    cout<<checkAC(s)<<endl;
    return 0;
}

2.洛谷板子(强化版)P3796

查询最多出现的模式串,并输出(可能有多个)

#include<iostream>
#include<cstring>
#include<cstdio>
#include<string>
#include<queue>
#include<map>
using namespace std;
#define INF 1e10+5
#define maxn 1000005
#define minn -105
#define ll long long int
#define ull unsigned long long int
#define uint unsigned int
struct trienode
{
    int failptr;
    int vis[26];
    int sum;
    int stringpos;
    void strienode(){memset(vis,0,sizeof(vis));failptr=0;sum=0;stringpos=0;}
};
trienode trie[maxn];
string stringsave[maxn];
map<string,int>Map;
int maxans;
int total[maxn];
int cnt=0;

void built(string s)
{
    int curpos=0;
    for(int i=0;i<s.size();i++)
    {
        if(!trie[curpos].vis[s[i]-'a'])trie[curpos].vis[s[i]-'a']=++cnt;
        curpos=trie[curpos].vis[s[i]-'a'];
    }
    trie[curpos].sum++;
    trie[curpos].stringpos=Map[s];
}
void built_fail()
{
    queue<int>curline;
    for(int i=0;i<26;i++)
    {
        if(trie[0].vis[i])trie[trie[0].vis[i]].failptr=0,curline.push(trie[0].vis[i]);
    }
    while(!curline.empty())
    {
        int curpos=curline.front();
        curline.pop();
        for(int i=0;i<26;i++)
        {
            if(trie[curpos].vis[i])
	            {
                trie[trie[curpos].vis[i]].failptr=trie[trie[curpos].failptr].vis[i];
                curline.push(trie[curpos].vis[i]);
            }
            else
            {
                trie[curpos].vis[i]=trie[trie[curpos].failptr].vis[i];
            }
        }
    }

}
void checkAC(string s)
{
    int ans=0,curpos=0;
    for(int i=0;i<s.size();i++)
    {
        curpos=trie[curpos].vis[s[i]-'a'];
        for(int j=curpos;j;j=trie[j].failptr)
        {
            total[trie[j].stringpos]+=trie[j].sum;
            maxans=max(maxans,total[trie[j].stringpos]);
        }
    }
}
int main()
{
    int _t;
    string s;
    while(1)
    {
        cin>>_t;
        if(!_t)break;
        int index=0;
        maxans=0;
        memset(total,0,sizeof(total));
        Map.clear();
        for(int i=0;i<100000;i++)
            trie[i].strienode();
        while(_t--)
        {
            cin>>s;
            if(s.size()>=maxn)continue;
            if(!Map[s])Map[s]=index,stringsave[index]=s,index++;
            built(s);
        }
        trie[0].failptr=0;
        built_fail();
        cin>>s;
        checkAC(s);
        cout<<maxans<<endl;
        for(int i=0;i<index;i++)
        {
            if(total[i]==maxans)cout<<stringsave[i]<<endl;
        }
    }

    return 0;
}

3.洛谷模版3(二次强化版)P5357

这个有一点值得注意的是

aa与aa出现在aaa中分别认为是2次,2次,而不是4次,4次

#include<iostream>
#include<cstring>
#include<cstdio>
#include<string>
#include<queue>
using namespace std;
#define INF 1e10+5
#define maxn 1000005
#define minn -105
#define ll long long int
#define ull unsigned long long int
#define uint unsigned int
struct trienode
{
    int failptr;
    int vis[26];
    int sum;
    int stringpos;
    trienode(){memset(vis,0,sizeof(vis));failptr=0;stringpos=0;sum=0;}
};
trienode trie[maxn];
queue<int>curline;
string s;
int total[maxn],Map[maxn],in[maxn];
int cnt=0;
int existpos=0;
int indexnum=0;
void built(int p)
{
    int curpos=0;
    int len=s.size();
    for(register int i=0;i<len;i++)
    {
        if(!trie[curpos].vis[s[i]-'a'])trie[curpos].vis[s[i]-'a']=++cnt;
        curpos=trie[curpos].vis[s[i]-'a'];
    }
    if(!trie[curpos].stringpos)trie[curpos].stringpos=++indexnum;
    Map[p]=trie[curpos].stringpos;
}
void built_fail()
{
    for(register int i=0;i<26;i++)
    {
        if(trie[0].vis[i])trie[trie[0].vis[i]].failptr=0,curline.push(trie[0].vis[i]);
    }
    while(!curline.empty())
    {
        int curpos=curline.front();
        curline.pop();
        for(int i=0;i<26;i++)
        {
            if(trie[curpos].vis[i])
            {
                int x=trie[curpos].vis[i];
                trie[x].failptr=trie[trie[curpos].failptr].vis[i];
                in[trie[x].failptr]++;
                curline.push(trie[curpos].vis[i]);
            }
            else
            {
                trie[curpos].vis[i]=trie[trie[curpos].failptr].vis[i];
            }
        }
    }
}
void checkAC()
{
    int ans=0,curpos=0;
	//targetstring 的遍历
    for(register int i=0;i<s.size();i++)
    {
        curpos=trie[curpos].vis[s[i]-'a'];
        trie[curpos].sum++;
    }
	//topu on DAG
    queue<int>curline;
    for(int i=1;i<=cnt;i++)if(!in[i])curline.push(i);
    while(!curline.empty())
    {
        int x=curline.front();curline.pop();
        total[trie[x].stringpos]=trie[x].sum;
        int y=trie[x].failptr;in[y]--;
        trie[y].sum+=trie[x].sum;
        if(in[y]==0)curline.push(y);
    }
}
int main()
{
    int _t;
    cin>>_t;
    memset(total,0,sizeof(total));
    memset(in,0,sizeof(in));
    for(register int i=1;i<=_t;i++)
    {
        cin>>s;
        built(i);
    }
    trie[0].failptr=0;
    built_fail();
    cin>>s;
    checkAC();
    for(register int i=1;i<=_t;i++)cout<<total[Map[i]]<<'\n';
    return 0;
}
posted @ 2020-04-05 17:12  et3_tsy  阅读(225)  评论(0编辑  收藏  举报