AC自动机小记

不知不觉这篇博客已经被我咕咕咕了一个月了。
也许我写过AC自动机的博客,但我也不知道我写没写过

前情回顾之\(kmp\)

\(kmp\)用来解决一个模式串匹配一个文本串的问题,其思路是设置失配指针,找与以当前字符的前一个字符结尾的后缀相同的最长的前缀的长度,失配时直接跳失配指针。复杂度\(O(m+n)\)

前情回顾之\(trie\)

\(trie\)就是字典树。顾名思义,将单词放在树上。这里用到的\(trie\)每个节点有26条出边,代表26个字母。每个节点是一个字母,这样一条路径上所有节点的字母就可以组成一个单词。支持插入,删除,查重balabala……

关于前情回顾部分就不多说了qwq

什么是AC自动机鸭

现在有好多模式串以及一个超长的文本串,让你求每个模式串出现的次数。
对每个模式串都来一次\(kmp\)?想法不错,可惜\(T\)了。
传说中,有一个东西叫做AC自动机,是在\(trie\)上进行\(kmp\)的神奇的东西,可以解决上面那个问题。以及它并不能帮助你自动AC

好了我们来建个AC自动机叭

上面说过,AC自动机是在\(trie\)的基础上进行\(kmp\)。放进\(trie\)中的是所有的模式串。那如何在\(trie\)上处理\(nxt\)数组呢?这里类似\(kmp\),找与当前字符节点的父节点结尾的某一段后缀相同的最长的前缀所在的位置。(在kmp中因为只有一个模式串,所以前缀长度也就是位置,trie不一样)。它的父节点的失配指针一定指向有相同前缀的一个节点,所以我们只需要看父节点的失配指针指向的节点是否有与这个节点字符相同的儿子即可。如果有,这个儿子即是它的失配指针,没有就继续跳失配指针,如果一直没有,失配指针就是\(root\)
来张图李姐李姐
模式串:ababb ababc bab

节点旁边的数字是节点编号,里面的字符就是这个节点的字符。其中0是root,规定0的所有出边都连向1,\(nxt[1]=0\)
首先,\(nxt[1]=0\)

遵循一个节点的失配指针是它父亲失配指针的对应的儿子的原则,发现1的两个儿子\(a\)\(b\)的失配指针都指向1.

寻找节点3的失配指针,走它的父亲的失配指针,到1。发现1有对应的儿子8号节点。

由于失配指针依赖父节点的失配指针,是从上往下一层一层的来,所以我们现在找9号节点的失配指针,发现是2号节点。

代码:

queue <int> q;
void pre()
{
	nxt[1]=0;
	q.push(1);
	while(!q.empty())
	{
		int u=q.front();
		q.pop();
		for(int i=0;i<26;i++)//用当前节点更新它的儿子
		{
			int v=trie[u][i];
			if(!v) 
				trie[u][i]=trie[nxt[u]][i];//对之后的节点来说,相当于直接跳失配指针
		    else
		    {
		    	q.push(v);
		    	int qwq=nxt[u];
		    	nxt[v]=trie[qwq][i];
			}
		}   
	}
}

统计答案

我们在\(trie\)上按照文本串的字符一直往下走,同时我们想统计文本串上每个字符对答案的贡献,也就是以该字符结尾,被包含的模式串的个数。根据\(nxt\)的找法,发现自己的失配指针所指的节点如果是某个模式串的结尾,那么就会对答案造成贡献。所以我们不妨一直跳失配指针来累加答案。

void getans()
{
	int now=1,len=strlen(ms),v;
	for(int i=0;i<len;i++)
	{
		int u=ms[i]-'a';
		v=trie[now][u];
		while(v)
		{
			if(en[v]==-1) break;//避免重复计算
			ans+=en[v];//可能会出现重复的单词
			en[v]=-1;
			v=nxt[v];
		}
		now=trie[now][u];
	}
}

这样最简单的AC自动机就Ok了,复杂度\(O((N+M)L)\),\(N\)是模式串个数,\(M\)是文本串长度,\(L\)是所有模式串的平均长度。
完整代码(luogu板子题):

#include<iostream>
#include<cstdio>
#include<cstring>
#include<cmath>
#include<algorithm>
#include<vector>
#include<map>
#include<queue>
using namespace std;
typedef long long ll;
typedef unsigned long long ull;
inline ll read()
{
	char ch=getchar();
	ll x=0;bool f=0;
	while(ch<'0'||ch>'9')
	{
		if(ch=='-') f=1;
		ch=getchar();
	}
	while(ch>='0'&&ch<='9')
	{
		x=(x<<3)+(x<<1)+(ch^48);
		ch=getchar();
	}
	return f?-x:x;
}
int n,trie[1000009][29],cnt=1,nxt[1000009],ans;
char ms[1000009];
int en[1000099];
void add()//建trie树
{
	int len=strlen(ms);
	int now=1;
	for(int i=0;i<len;i++)
	{
		int ch=ms[i]-'a';
		if(!trie[now][ch])
			trie[now][ch]=++cnt;
		now=trie[now][ch];	
	}
	en[now]++;
}
queue <int> q;
void pre()
{
	nxt[1]=0;
	q.push(1);
	while(!q.empty())
	{
		int u=q.front();
		q.pop();
		for(int i=0;i<26;i++)
		{
			int v=trie[u][i];
			if(!v) 
				trie[u][i]=trie[nxt[u]][i]; 
		    else
		    {
		    	q.push(v);
		    	int qwq=nxt[u];
		    	nxt[v]=trie[qwq][i];
			}
		}   
	}
}
void getans()
{
	int now=1,len=strlen(ms),v;
	for(int i=0;i<len;i++)
	{
		int u=ms[i]-'a';
		v=trie[now][u];
		while(v)
		{
			if(en[v]==-1) break;
			ans+=en[v];
			en[v]=-1;
			v=nxt[v];
		}
		now=trie[now][u];
	}
}
int main()
{
	n=read();
	for(int i=1;i<=n;i++)
	{
		scanf("%s",ms);
		add();
	}
	for(int i=0;i<26;i++)
	 trie[0][i]=1;
	pre();
	scanf("%s",ms);
	getans();
	printf("%d\n",ans);
} 

拓扑建图优化

现在有一个duliu出题人,让你求每个模式串出现了多少次。
显然我们可以用普通的AC自动机来不断跳失配指针来更新每个点的答案。
普通的AC自动机的复杂度是\(O(\)模式串长度\(\times\)文本串长度\()\)的,在不友好的数据面前容易被卡成AC自闭机,所以我们要想办法让它变成\(O(\)模式串长度\()\)的。
普通AC自动机的复杂度都浪费在哪里了呢?结合刚才那张图看一下

我们统计答案时,要不断暴力跳失配指针,这样,8号节点就被更新了多次。我们考虑一下有没有能一次更新完8号节点的答案的方法。在普通AC自动机更新答案的过程中,8号节点被3,10,5,6号节点更新,也就是说8号节点的答案由3,10,5,6号节点转移过来。那么我们可以想到用一个标记记录下答案,然后从底下往上更新每个点的答案,顺带更新标记。
这有没有很像拓扑?所以我们可以用拓扑图来实现,将\(nxt\)\(trie\)上的边都看作是拓扑图里的边。
luogu AC自动机二次加强版的板子

#include<iostream>
#include<cstdio>
#include<cmath>
#include<cstring>
#include<algorithm>
#include<vector>
#include<queue>
#include<map>
using namespace std;
typedef long long ll;
typedef unsigned long long ull;
inline ull read()
{
	char ch=getchar();
	ull x=0; 
	bool f=0;
	while(ch<'0'||ch>'9')
	{
		if(ch=='-') f=1;
		ch=getchar();
	}
	while(ch>='0'&&ch<='9')
	{
		x=(x<<3)+(x<<1)+(ch^48);
		ch=getchar();
	}
	return f?-x:x;
}
int n,trie[1000009][29],cnt=1,nxt[1000009];
string ms;
int en[1000099],t,ys[1000099],in[1000099];
int ans[1000009],qaq[1000009];
void add(int lz)
{
	int len=ms.size();
	int now=1;
	for(int i=0;i<len;i++)
	{
		int ch=ms[i]-'a';
		if(!trie[now][ch])
			trie[now][ch]=++cnt;
		now=trie[now][ch];	
	}
	en[now]++;
	if(en[now]==1)
	 ys[now]=lz;//记录当前单词结尾节点的出现的第一个模式串的编号
	else   qaq[lz]=ys[now],en[now]=1;//记录一个映射,因为题目中要求每个模式串的答案,对于重复出现的模式串,编号为i,ans[i]=ans[qaq[i]]
}
queue <int> q;
void pre()
{
	nxt[1]=0;
	q.push(1);
	while(!q.empty())
	{
		int u=q.front();
		q.pop();
		for(int i=0;i<26;i++)
		{
			int v=trie[u][i];
			if(!v) 
				trie[u][i]=trie[nxt[u]][i];
		    else
		    {
		    	q.push(v);
		    	int qwq=nxt[u];
		    	nxt[v]=trie[qwq][i];
		    	in[nxt[v]]++;
			}
		}   
	}
}
int an[1000099];
void getans()
{
    int len=ms.size(),now=1;
    for(int i=0;i<len;i++)
    {
    	int c=ms[i]-'a';
    	now=trie[now][c];
    	an[now]++;//an就是一个标记
	}
	for(int i=1;i<=cnt;i++)
		if(!in[i]) q.push(i);
	while(!q.empty())
	{
		int u=q.front();
		q.pop();
		if(en[u])
		  ans[ys[u]]=an[u];
		int v=nxt[u];
		an[v]+=an[u];in[v]--;
		if(!in[v]) q.push(v);
	}	

}
int main()
{
	n=read();
	if(n==0) return 0;
	for(int i=1;i<=n;i++)
	{
	   cin>>ms;
	   add(i);
 	}
	for(int i=0;i<26;i++)
	    trie[0][i]=1;
	pre();
	cin>>ms;
	getans();
    for(int i=1;i<=n;i++)
    { 
    	if(!ans[i]) ans[i]=ans[qaq[i]];
    	printf("%d\n",ans[i]);
	}
} 
posted @ 2019-12-23 17:22  千载煜  阅读(207)  评论(0编辑  收藏  举报