Loading

AC自动机---好东西

前言

想学AC自动机,请先学会KMPTrie树

正文

AC自动机到底是什么呢?\(Trie+KMP\)

我们的AC自动机和KMP有什么区别呢?

AC自动机可以同时匹配很多个模式串,而KMP只可以匹配一个模式串。

那么我们来看一道例题:传送门

我们可以想到对于每一个字符串进行一次KMP,然鹅会TLE。

AC自动机就是一个算法,它可以同时进行多次KMP。

我们举一个例子:

2020-12-31 19-28-52.JPG

我们可以惊奇的发现:

2020-12-31 19-28-52.JPG

这里\(she\)的后缀\(he\),和\(he\)是一样的,这时候我们想到了什么,这和我们的KMP算法突然相似了起来?

我们可以建立一个\(fail_{cur}\)的数组,在\(she\)后的\(e\)建立一个指向\(he\)中的\(e\)的指针,和KMP中的\(next\)数组很想。像这样

2020-12-31 19-43-55.JPG

我们就可以很方便的进行匹配了。

现在我们的关键就在于如何求出\(fail\)

我们观察到这一个\(fail\)数组是可以递归地定义地,我们想用常规方法很难给出求法

这时,我们的数学归纳法就起作用了。

数学归纳法是什么意思呢?就是假设我们的上一步已经求得结果,我们用上一步的结果来求这一个结果。我们就可以从一些基础的情况推出所有的情况了。

假设我们有一点\(trie[cur][i]\),它的父亲\(cur\),若\(cur\)\(fail\)已经求出,那么只会有两种情况

  1. 我们的\(trie[cur][i]\)不为空,那么就直接把\(fail[trie[cur][i]]=fail[cur][i]\),意思就是在自己的\(fail_{cur}\)下寻找一个\(i\)代表的数组。
  2. \(trie[cur][i]\)为空,直接把这个点\(trie[cur][i]=trie[fail[cur]][i]\),直接跑到\(fail\)那里去。

这样就好了。

我们查找的时候就一个字符一个字符的找,每一次针对一个字符,不停地根据\(fail\)往上跳,直到跳不动为止。

在这个过程中,我们要注意统计一次答案,就要把\(Trie\)中的计数器\(cnt\)清空为\(-1\),代表跳到这里就跳不了了。

我们就可以轻松的打出代码

Code:

#include<iostream>
#include<queue>
using namespace std;
const int N = 1000005;
int trie[N][28], tot, num[N], fail[N];
queue<int> q;

void insert(string s) 
{
	int cur = 0;
	for(int i = 0; i < s.size(); i++) 
	{
		if(trie[cur][s[i]-'a'] == 0)
		{
			tot++; 
			trie[cur][s[i]-'a'] = tot;
		}
		cur = trie[cur][s[i]-'a'];
	}
	num[cur]++; //统计以cur结尾的单词出现的次数 
	return ;
}

void get_fail() //bfs构建fail数组 
{
	for(int i = 0; i < 26; i++) //根结点下面直接连的第一层结点,fail直接指向根结点0 
		if(trie[0][i]) 
			q.push(trie[0][i]);
			
	while(q.empty() == false) //队列中维护能够拓展fail值的结点 
	{
		int cur = q.front();
		q.pop();
		for(int i = 0; i < 26; i++) 
		{
			if(trie[cur][i])
			{
				//失配时,以trie[u][i]结尾的后缀尽量在trie中找一个与之相同的前缀(类似KMP) 
				fail[trie[cur][i]] = trie[fail[cur]][i];
				q.push(trie[cur][i]);
			}
			else //节点不存在,往上连,最多回到根结点0, 注意是trie不是fail数组 
				trie[cur][i] = trie[fail[cur]][i];
		}
	}
	return ;
}

int query(string t) //询问,t是文本串 
{
	int cur = 0, res = 0; //cur表示trie中的结点 
	for(int i = 0; i < t.size(); i++) 
	{
		cur = trie[cur][t[i] - 'a']; //获取t[i]所对应的结点 
		for(int j = cur; j && num[j] != -1; j = fail[j]) 
		{
	  		res += num[j];
			num[j] = -1; //标记为统计过 
		}
	}
	return res;
}

int main() 
{
	int n;
	string s;
	cin >> n;
	for(int i = 1; i <= n; i++) 
	{
		cin >> s;
		insert(s);
	}
	cin >> s; //文本串 
	get_fail();
	cout << query(s);
	return 0;
}

扩展

P3796 【模板】AC自动机(加强版)

我们注意到这道题中题目强调了:保证不存在两个相同的模式串

这就意味着我们的\(Trie\)树中没有一个叶子节点会有两个模式串同时经过。

那么我们用\(id\)数组可以记录是哪个模式串的叶子节点。

我们在查询的时候再统计\(num\)数组

有的人会问这不会有误差吗?

没有关系,我们只关注叶子节点的\(num\)值,其它的可能会错,但是那不重要。

就好了

Code

#include<iostream>
#include<queue>
#include<cstring>
using namespace std;
const int N = 1000005;
int trie[N][28], tot, num[N], fail[N], id[N], ans[N], n;
queue<int> q;


void insert(string s, int k) //k表示第几个模式串 
{
	int cur = 0;
	for(int i = 0; i < s.size(); i++) 
	{
		if(trie[cur][s[i]-'a'] == 0)
		{
			tot++;
			trie[cur][s[i]-'a'] = tot;
		}
		cur = trie[cur][s[i]-'a'];
	}
	id[cur] = k;
	return ;
}

void get_fail() //bfs构建fail数组 
{
	for(int i = 0; i < 26; i++)
	{
		if(trie[0][i]) 
			q.push(trie[0][i]);
	}
	while(q.empty() == false) 
	{
		int cur = q.front();
		q.pop();
		for(int i = 0; i < 26; i++) 
		{
			if(trie[cur][i])
			{
				fail[trie[cur][i]] = trie[fail[cur]][i];
				q.push(trie[cur][i]);
			}
			else
				trie[cur][i] = trie[fail[cur]][i];
		}
	}
	return ;
}

int query(string t) //询问 
{
	int cur = 0, res = 0; //cur表示trie中的结点 
	for(int i = 0; i < t.size(); i++) 
	{
		cur = trie[cur][t[i] - 'a']; //获取t[i]所对应的结点 
		for(int j = cur; j != 0; j = fail[j]) 
	  		num[j]++;
	}
	for(int i = 0; i <= tot; i++) //tot结点编号 
	{
		if(id[i] != 0)
		{
			res = max(res, num[i]);
			ans[id[i]] = num[i];
		}
	}
	return res;
}

void work()
{
	string s[155];
	for(int i = 1; i <= n; i++) 
	{
		cin >> s[i];
		insert(s[i], i);
	}
	string ss;
	cin >> ss;
	get_fail();
	int maxi = query(ss);
	cout << maxi << "\n";
	for(int i = 1; i <= n; i++)
		if(ans[i] == maxi)
			cout << s[i] << "\n";
	return ;
}

int main() 
{
	while(cin >> n && n != 0)
	{
		memset(num, 0, sizeof(num));
		memset(trie, 0, sizeof(trie));
		memset(id, 0, sizeof(id));
		memset(fail, 0, sizeof(fail));
		tot = 0;
		work();
	}
	return 0;
}

P5357 【模板】AC自动机(二次加强版)

二次加强版,顾名思义,它一定有一些别的要求

当我们用上一题的代码提交的话------\(TLE\)!

怎么办?注意:数据不保证任意两个模式串不相同。

在统计中,我们原本有id[cur] = k;

我们用vector<int> id[]来解决问题

但是还是会\(TLE\)

我们考虑一个极限状况,文本串和模式串都为aaaaaaa

那么它的\(fail\)数组就会指向自己的父亲。

我们的目标串每移动一个位置,我们就必须用\(fail\)数组往上跳\(n\)次。

那么为甚么会这么慢呢?原来是在上一道题目中它保证每一个模式串都不同,就不会出现这种极限的状况。

我们在这里思考一下,能不能在统计的时候只改变一次就可以了。

我们回顾一下\(fail\)数组的含义,它是记录最长的前后缀相等......

这就意味着我们只要在一个节点匹配成功,它\(fail\)数组所指向的点也会出现。

此时我们意识到\(fail\)指针构成了一个\(DAG\),我们就可以用一个树形DP来统计这一个节点对其他节点的贡献。

这就很nice了

代码实现很简单(指思路)

#include<iostream>
#include<queue>
#include<cstring>
using namespace std;
const int N = 200005;
int trie[N][28], tot, dp[N], fail[N], id[N], ans[N], head[N], cnt, n;
queue<int> q;

struct node
{
	int to, nxt;
}edges[N];

void addedge(int a, int b)
{
	cnt++;
	edges[cnt].to = b;
	edges[cnt].nxt = head[a];
	head[a] = cnt;
	return ;
}

void insert(string s, int k) //k表示第几个模式串 
{
	int cur = 0;
	for(int i = 0; i < s.size(); i++) 
	{
		if(trie[cur][s[i]-'a'] == 0)
		{
			tot++;
			trie[cur][s[i]-'a'] = tot;
		}
		cur = trie[cur][s[i]-'a'];
	}
	id[k] = cur; //记录第k个模式串在trie树中以cur结点结尾 
	return ;
}

void get_fail() //bfs构建fail数组 
{
	for(int i = 0; i < 26; i++)
	{
		if(trie[0][i]) 
			q.push(trie[0][i]);
	}
	while(q.empty() == false) 
	{
		int cur = q.front();
		q.pop();
		for(int i = 0; i < 26; i++) 
		{
			if(trie[cur][i])
			{
				fail[trie[cur][i]] = trie[fail[cur]][i];
				q.push(trie[cur][i]);
			}
			else
				trie[cur][i] = trie[fail[cur]][i];
		}
	}
	return ;
}

void counting(string t) //询问 
{
	int cur = 0; //cur表示trie中的结点 
	for(int i = 0; i < t.size(); i++) 
	{
		cur = trie[cur][t[i] - 'a']; //获取t[i]所对应的结点  
	  	dp[cur]++; //cur结点经过的次数 
	}
	return ;
}

void dfs(int cur, int fa) //树形DP, 统计子结点对父节点的贡献 
{
	for(int i = head[cur]; i != 0; i = edges[i].nxt)
	{
		int to = edges[i].to;
		if(to == fa) //实际上不需要记录fa, why? 
			continue;
		dfs(to, cur);
		dp[cur] += dp[to];
	}
	return ;
} 

int main()
{
	cin >> n;
	for(int i = 1; i <= n; i++) 
	{
		string s;
		cin >> s;
		insert(s, i);
	}
	string ss;
	cin >> ss;
	get_fail(); //求fail数组 
	counting(ss); //统计每个结点经过的次数,但是不再用fail跳
	
	//由fail数组建边, 做树形DP, 统计每个点的子树对其自身的贡献 
	for(int i = 1; i <= tot; i++) 
		addedge(fail[i], i);
		
	dfs(0, -1); //树形DP
	 
	for(int i = 1; i <= n; i++) //输出每个模式串i对应的trie结点id[i]被经过的次数dp[id[i]] 
		cout << dp[id[i]] << "\n";
	return 0;
}

这道题给我们了一些启发

  1. AC自动机除了一颗\(Trie\)树之外,还有一个\(fail\)
  2. 所以我们可以将树上DP,树上查分,莫队,甚至树链剖分和AC自动机结合起来形成一些毒瘤的题目

P3041 [USACO12JAN]Video Game G

这道题很好!!!

我们很容易的想出这个问题的子问题,文本串长度为\(k-1,k-2......\)

我们想到DP,定义\(dp_i\)表示长度为\(i\)的文本串能够得到的最大得分

但是这样不好转移,我们就增加一维\(dp_{i,j}\)表示字符串长度为\(i\)且在\(Trie\)树中\(j\)节点

状态伪代码

dp[i][trie[j][c]]=max(dp[i][trie[j][c]], dp[i-1][j]+val[trie[j][c]]);

这样定义我们能跑树形DP了吗?不能,这个状态有两维,有后效性

我们可以用树形DP来预处理,就是在GetFail函数中加一行val[cur]+=val[fail[cur]]

就可以啦!!!!

#include <bits/stdc++.h>
using namespace std;
const int MX = 305;
int trie[MX][26], fail[MX], val[MX], total, n, k, dp[1005][MX];
void ins(string s)
{
	int len = s.size(), cur = 0;
	for(int i = 0; i < len; i ++)
	{
		int to = s[i] - 'A';
		if(trie[cur][to] == 0)
			trie[cur][to] = ++ total;
		cur = trie[cur][to];
	}
	val[cur] ++;
	return ;
}
void GetFail()
{
	queue<int> q;
	for(int i = 0; i < 3; i++)
	{
		if(trie[0][i]) 
			q.push(trie[0][i]);
	}
	while(q.empty() == false) 
	{
		int cur = q.front();
		q.pop();
		for(int i = 0; i < 3; i++) 
		{
			if(trie[cur][i])
			{
				fail[trie[cur][i]] = trie[fail[cur]][i];
				q.push(trie[cur][i]);
			}
			else
				trie[cur][i] = trie[fail[cur]][i];
		}
		val[cur] += val[fail[cur]];
	}
	return ;
}
void DP(int k)
{
	memset(dp, 0xcf, sizeof(dp));
	for(int i = 0; i <= k; i++)
		dp[i][0] = 0;
    for(int i = 1; i <= k; i++)
        for(int j = 0; j <= total; j++)
            for(int c = 0; c < 3; c++)
                dp[i][trie[j][c]] = max(dp[i][trie[j][c]], dp[i-1][j] + val[trie[j][c]]);
	return ;    
}
int main()
{
	ios::sync_with_stdio(false);
	cin >> n >> k;
	for(int i = 1; i <= n; i ++)
	{
		string s;
		cin >> s;
		ins(s);
	}
	GetFail();
	DP(k);
	int anss = INT_MIN;
	for(int i = 0; i <= total; i ++)
		anss = max(dp[k][i], anss);
	cout << anss;
	return 0;
}

总结

AC自动机是一个好东西!!!!

题目清单:

  1. P3121 [USACO15FEB]Censoring G
  2. P5231 [JSOI2012]玄武密码
posted @ 2020-12-31 20:38  zhangwenxuan  阅读(87)  评论(0)    收藏  举报