Loading

【瞎口胡】AC 自动机

AC 自动机用来解决多模式串匹配问题。

以下便是一个经典问题

给定 \(n\) 个模式串 \(S_1,S_2,...,S_n\) 和一个文本串 \(T\)。问有多少个模式串在文本串中出现过。

\(\sum |S_i| \leq 10^6,|T| \leq 10^6\)

考虑对模式串建出 trie。在 trie 的每个节点额外记录一个 fail,表示根到该节点表示的字符串在树中的最长后缀的节点编号。

[图没了]

对于红色箭头指向的节点,由于 \(\texttt{abc}\) 在树中的最长后缀是 \(\texttt{bc}\),所以红色节点的 fail 指向蓝色节点。特殊的,如果一个节点在树中找不到后缀,那么让它的 fail 指向根节点。

在求 fail 时可以这样写:

inline void Build_Fail(void){
	std::queue <int> q;
	while(!q.empty())
		q.pop();
	for(rr int i=0;i<26;++i){ // 第一圈的节点肯定没有 fail
		if(trie[0].next[i]){
			q.push(trie[0].next[i]);
		}
	}
	while(!q.empty()){
		int i=q.front();
		q.pop();
		for(rr int j=0;j<26;++j){
			if(!trie[i].next[j]){
				trie[i].next[j]=trie[trie[i].fail].next[j];//Trie 中没有这个点 特殊处理
				continue;
			}
			trie[trie[i].next[j]].fail=trie[trie[i].fail].next[j];//类似于 KMP 的思想
			q.push(trie[i].next[j]);//压入 继续更新
		}
	}
	return;
}

而在匹配的时候,文本串直接在 trie 上走就好了。设走完 \(i\) 次后到了一个点 \(j\),那么说明以 \(T_i\) 结尾的文本串就在 \(trie_j\) 上跳 fail 就好了。

inline int Query(char *s){
	int len=strlen(s);
	int p=0;
	int ans=0;
	for(rr int i=0;i<len;++i){
		int j=trie[p].next[s[i]-'a'];
		while(j&&~trie[j].cnt){ // 防止重复计算 & 保证时间复杂度
			ans+=trie[j].cnt;
			trie[j].cnt=-1;
            
			j=trie[j].fail;
		}
		p=trie[p].next[s[i]-'a'];
	}
	return ans;
}

优化 - 拓扑建图

对于问题的一个加强版,要求每个模式串在文本串中的出现次数。

这个时候,不能用经典问题中的 给 trie 上节点标 \(-1\) 来保证复杂度了。

因为每个节点都有一个唯一的 fail,于是将每个节点和它的 fail 连边,可以建成一个 DAG。在这个 DAG 上拓扑排序就好了。

# include <bits/stdc++.h>
# define rr
const int N=200010,INF=0x3f3f3f3f;
struct Node{
	int fail;
	int next[26];
}trie[N];
int endflag[N];
char S[N*10];
int cnt;
char c[N];
int id[N],du[N],v[N];
int n;
int ans[N];
inline int read(void){
	int res,f=1;
	char c;
	while((c=getchar())<'0'||c>'9')
		if(c=='-')f=-1;
	res=c-48;
	while((c=getchar())>='0'&&c<='9')
		res=res*10+c-48;
	return res*f;
}
inline void Insert(char *s,int x){
	int p=0,len=strlen(s);
	for(rr int i=0;i<len;++i){
		if(!trie[p].next[s[i]-'a']){
			trie[p].next[s[i]-'a']=++cnt;
		}
		p=trie[p].next[s[i]-'a'];
	}
	if(!endflag[p]){
		endflag[p]=x;
	}
	id[x]=endflag[p];
	return;
}
inline void GetFail(void){
	std::queue <int> q=std::queue <int>();
	for(rr int i=0;i<26;++i){
		if(trie[0].next[i]){
			q.push(trie[0].next[i]);
		}
	}
	while(!q.empty()){
		int x=q.front();
		q.pop();
		for(rr int i=0;i<26;++i){
			if(!trie[x].next[i]){
				trie[x].next[i]=trie[trie[x].fail].next[i];
				continue;
			}
			trie[trie[x].next[i]].fail=trie[trie[x].fail].next[i];
			++du[trie[trie[x].next[i]].fail];
			q.push(trie[x].next[i]);
		}
	}
	return;
}
inline void query(void){
	int p=0,len=strlen(S);
	for(rr int i=0;i<len;++i){
		p=trie[p].next[S[i]-'a'];
		++v[p]; // 跳到点 p 的次数
	}
	return;
}
inline void topsort(void){
	std::queue <int> q=std::queue <int> ();
	for(rr int i=1;i<=cnt;++i){
		if(!du[i]){
			q.push(i);
		}
	}
	while(!q.empty()){
		int i=q.front();
		q.pop();
		ans[endflag[i]]=v[i];
		--du[trie[i].fail];
		v[trie[i].fail]+=v[i];
		if(!du[trie[i].fail]){
			q.push(trie[i].fail);
		}
	}
}
int main(void){
	n=read();
	for(rr int i=1;i<=n;++i){
		scanf("%s",c);
		Insert(c,i);
	}
	scanf("%s",S);
	GetFail();
	query();
	topsort();
	for(rr int i=1;i<=n;++i){
		printf("%d\n",ans[id[i]]);
	}
	return 0;
}
posted @ 2021-09-11 19:59  Meatherm  阅读(40)  评论(0编辑  收藏  举报