AC自动机 基础篇

AC 自动机1

前置知识:\(KMP\),字典树。

\(AC\) 自动机,不是用来自动 \(AC\) 题目的,而是用来处理字符串问题的(虽然确实可以帮助你 \(AC\))。

这里总结了 \(AC\) 自动机三大步骤。

插入

考虑字典树,我们直接把所有模式串插入到字典树内即可,这并不困难,代码:

void ins(){
    int p=0;
    for(int i=0;str[i];i++){
        int t=str[i]-'a';
        if(!tr[p][t])tr[p][t]=++idx;
        p=tr[p][t];
    }
    cnt[p]++;
}

构建失配指针

这个比较恶心。是这样一个思想:若其父节点失配指针指向的点存在一个儿子节点代表的字符与其相同,那么这个点的失配指针指向该子节点,否则指向根。

如果说一个节点 \(i\) 没有一个儿子,那么假设他父亲的失配指针指向 \(j\)。那么把 \(i\) 的这个儿子指向 \(j\)的这个儿子(注意是儿子,不是失配指针)。

我们先说一下为什么这样做是正确的。先设从根到 \(i\) 构成的字符串为 \(s\),从根到 \(j\) 构成的字符串为 \(t\)。那么有 \(t\)\(s\) 的后缀。既然 \(s\) 是文本串的子串,\(t\)\(s\) 的子串,那么 \(t\) 是文本串的子串。所以这样跳过去,一定是不会有问题的。

然后我们说一下为什么这样做能起到优化的作用。如果你不做这样把儿子指过去的操作,那么如果你不还原文本串匹配的位置,你的匹配会出错,会少算东西;如果你还原了,那就变成暴力了。所以这样做是能优化的(虽然还是能被卡的)。

这里画个图方便理解:

那么失配指针的意义是什么呢?假设节点 \(u\) 的失配指针指向节点 \(v\),那么其含义为从根\(v\) 的路径代表的字符串 \(s\) 与从这个字符往回走 \(s\) 的长度这一段所代表的字符串相同。画个图方便理解:

上面那句话的意思就是这两个红圈部分代表的字符串相同。

代码(手写队列版):

void build(){
    int hh=0,tt=-1;
    for(int i=0;i<26;i++){
        if(tr[0][i]){
            q[++tt]=tr[0][i];
        }
    }
    while(hh<=tt){
        int t=q[hh++];
        for(int i=0;i<26;i++){
            int c=tr[t][i];
            if(!c)tr[t][i]=tr[ne[t]][i];
            else{
                ne[c]=tr[ne[t]][i];
                q[++tt]=c;
            }
        }
    }
}

查询

假设当前查询到了点 \(i\),我们考虑从根到 \(i\) 构成的字符串为 \(s\),那么如果你当前跳到节点 \(j=ne_i\),显然从根到 \(j\) 构成的字符串 \(t\)\(s\) 的一个后缀,那么既然你当前查询的字符串 \(str\) 在字典树上走到了 \(i\) 这个点,证明 \(s\)\(str\) 的子串,而因为 \(t\)\(s\) 的一个后缀,所以 \(t\)\(s\) 的子串,即 \(t\)\(str\) 的子串。如果 \(t\) 恰好是一个独立的字符串,那么我们就可以统计答案了。

代码:

int query(string s){
	int p=0,res=0;
	for(int i=0;i<s.size();i++){
		int ch=s[i]-'a';
		p=tr[p][ch];
		for(int t=p;t&&cnt[t]!=-1;t=ne[t]){
			res+=cnt[t];
			cnt[t]=-1;
		}
	}
	return res;
}

完整代码:

#include<bits/stdc++.h>
#define N 10005
#define M 1000005
#define S 205
using namespace std;
int n,tr[N*S][26],cnt[N*S],q[N*S],ne[N*S],idx,ed[N*S];
char str[M];
void ins(){
    int p=0;
    for(int i=0;str[i];i++){
        int t=str[i]-'a';
        if(!tr[p][t])tr[p][t]=++idx;//没有点就开一个 
        p=tr[p][t];
    }
    cnt[p]++;//这个点打一个结束标记 
}
void build(){
    int hh=0,tt=-1;
    for(int i=0;i<26;i++){
        if(tr[0][i]){
            q[++tt]=tr[0][i];//第一层的失配指针都指向根 
        }
    }
    while(hh<=tt){
        int t=q[hh++]; 
        for(int i=0;i<26;i++){
            int c=tr[t][i];
            if(!c)tr[t][i]=tr[ne[t]][i];//把儿子指过去 
            else{
                ne[c]=tr[ne[t]][i];//指向我父亲的失配指针指向的节点的这个儿子 
                q[++tt]=c;
            }
        }
    }
}
int query(string s){
	int p=0,res=0;
	for(int i=0;i<s.size();i++){
		int ch=s[i]-'a';
		p=tr[p][ch];
		for(int t=p;t&&cnt[t]!=-1;t=ne[t]){//到达p,则ne_t路径代表的字符串被包含,判断是否算贡献 
			res+=cnt[t];
			cnt[t]=-1;//只算一次 
		}
	}
	return res;
}
signed main(){
    int T;
    T=1;
    while(T--){
        cin>>n;
        for(int i=0;i<n;i++){
            cin>>str;
            ins();
        }
        build();
        cin>>str;
        cout<<query(str);
    }
    return 0;
}

AC 自动机2

注意这道题的字符串是互不相同的,和下一道题不一样。

我们考虑记录每一个字符串的编号和该字符串。由于字符串互不相同,所以一个点最多存储一个结束标记。而我们把这个结束标记存成这个字符串的编号 \(id\)

然后其他的东西基本上没有区别。就是在查询的时候如果这个点有结束标记,那就把这个字符串的出现次数加上 \(1\)

代码:

#include<bits/stdc++.h>
#define int long long
#define N 155
#define M 1000005
#define S 75
using namespace std;
int n,tr[N*S][26],cnt[N*S],q[N*S],ne[N*S],idx,st[N];
char str[N][S];
string s;
void ins(int id){
	int p=0;
	for(int i=0;str[id][i];i++){
		int t=str[id][i]-'a';
		if(tr[p][t]==0)tr[p][t]=++idx;
		p=tr[p][t];
	}
	cnt[p]=id;
}
void build(){
	int hh=0,tt=-1;
	for(int i=0;i<26;i++){
		if(tr[0][i]!=0){
			q[++tt]=tr[0][i];
		}
	}
	while(hh<=tt){
		int t=q[hh++];
		for(int i=0;i<26;i++){
			int c=tr[t][i];
			if(c==0){
				tr[t][i]=tr[ne[t]][i];
			}
			else{
				ne[c]=tr[ne[t]][i];
				q[++tt]=c;
			}
		}
	}
}
void qry(string s){
	int p=0,res=0;
	for(int i=0;i<s.size();i++){
		int ch=s[i]-'a';
		p=tr[p][ch];
		for(int t=p;t!=0;t=ne[t]){
			if(cnt[t]!=0)st[cnt[t]]++;
		}
	}
}
void clear(){
	idx=0;
	memset(tr,0,sizeof tr);
	memset(st,0,sizeof st);
	memset(cnt,0,sizeof cnt);
	memset(ne,0,sizeof ne);
}
signed main(){
	while(cin>>n,n){
		clear();
		for(int i=1;i<=n;i++){
			cin>>str[i];
			ins(i);
		}
		build();
		cin>>s;
		qry(s);
		int res=0;
		for(int i=1;i<=n;i++){
			res=max(res,st[i]);
		}
		cout<<res<<'\n';
		for(int i=1;i<=n;i++){
			if(st[i]==res){
				cout<<str[i]<<'\n';
			}
		}
	}
	return 0;
}

AC 自动机3

\(40\)

直接把上一题的最大值变成输出每个字符串的出现次数。可以获得 \(40\) 分,同时又红,黑,绿三种颜色。

\(76\)

发现一个事情,字符串有重复,所以我们打标记的方式需要改变。

如果这个字符串没有出现过,我们就直接插入;否则不插入。并记录这个字符串是第几种字符串。

但是你发现这样超时了。

\(100\)

考虑怎么卡 \(AC\) 自动机:

这样就变成了一个暴力,我们考虑怎么优化。

可以发现,如果更新第 \(1\)C 时,会依次跳到后两个 C。但是你在从第 \(2\)C 时,最后一个 C 又要被跳一次。如果这样构造数据,会直接被卡飞。那么,我们有没有什么办法让每个点只被经过一次呢?

显然是有的,我们可以使用拓扑排序。

我们如果找到了一个点有贡献,那么我们在这个点打一个标记,不再继续向下跳。最后跳一遍,上传这些标记。

就比方说,你现在到达了第一个 C,你在这个点打个 \(1\) 的标记,然后直接匹配文本串的下一个字符。等到最后,从第一个 C 开始往其 \(ne\) 数组指向的位置 \(j\) 跳(假设第一个 C 的节点编号为 \(i\)),然后 \(cnt_j+cnt_i\leftarrow cnt_j\)

可以发现,每个点的入度是任意的,但是出度一定为 \(1\)。所以我们可以直接按照 \(ne\) 数组指向的位置去跳。

于是我们就做完了,代码:

#include<bits/stdc++.h>
#define int long long
#define N 2000005
using namespace std;
int n,tr[N][26],id[N],ne[N],idx,cnt[N],res[N],din[N],mp[N];
string s,str;
void ins(int x){
	int p=0;
	for(int i=0;i<str.size();i++){
		int t=str[i]-'a';
		if(tr[p][t]==0)tr[p][t]=++idx;
		p=tr[p][t];
	}
	if(id[p]==0)id[p]=x;
	mp[x]=id[p];
}
void build(){
	queue<int>q;
	for(int i=0;i<26;i++){
		if(tr[0][i]!=0){
			q.push(tr[0][i]);
		}
	}
	while(!q.empty()){
		int t=q.front();
		q.pop();
		for(int i=0;i<26;i++){
			int c=tr[t][i];
			if(c==0){
				tr[t][i]=tr[ne[t]][i];
			}
			else{
				ne[c]=tr[ne[t]][i];
				din[ne[c]]++;
				q.push(c);
			}
		}
	}
}
void qry(string s){
	int p=0;
	for(int i=0;i<s.size();i++){
		int ch=s[i]-'a';
		p=tr[p][ch];
		cnt[p]++;
	}
}
void topo(){
	queue<int>q;
	for(int i=1;i<=idx;i++){
		if(din[i]==0){
			q.push(i);
		}
	}
	while(!q.empty()){
		int t=q.front(),f=ne[t];
		q.pop();
		res[id[t]]=cnt[t];
		din[f]--;
		cnt[f]+=cnt[t];
		if(din[f]==0)q.push(f);
	}
}
signed main(){
	int tot=0;
	cin>>n;
	for(int i=1;i<=n;i++){
		cin>>str;
		ins(i);
	}
	build();
	cin>>s;
	qry(s);
	topo();
	for(int i=1;i<=n;i++){
		cout<<res[mp[i]]<<'\n';
	}
	return 0;
}

单词

首先题意是统计每个单词在所有单词中出现的次数,注意单词不能拼起来。

我们定义 \(las_i\) 为从 \(i\) 节点出发沿着 \(ne\) 数组跳能跳到的第一个有中止标记的节点。于是我们沿着 \(ne\) 跳变成沿着 \(las\) 跳即可。

然后说一下怎么得到 \(las_i\),我们只需要在计算 \(ne_i\) 判断一下 \(ne_i\) 是否有终止标记,如果有那么 \(las_i\) 就是 \(ne_i\);否则是 \(las_{ne_i}\)

代码:

#include<bits/stdc++.h>
#define int long long
#define N 1000205
using namespace std;
int n,tr[N][26],las[N],res[N],ne[N],mp[N],id[N],idx;
string s,t;
void ins(int x){
	int p=0;
	for(int i=0;s[i];i++){
		int t=s[i]-'a';
		if(tr[p][t]==0)tr[p][t]=++idx;
		p=tr[p][t];
	}
	if(id[p]==0)id[p]=x;
	mp[x]=id[p];
}
void build(){
	queue<int>q;
	for(int i=0;i<26;i++){
		if(tr[0][i]!=0){
			q.push(tr[0][i]);
		}
	}
	while(!q.empty()){
		int t=q.front();
		q.pop();
		for(int i=0;i<26;i++){
			int c=tr[t][i];
			if(c==0){
				tr[t][i]=tr[ne[t]][i];
			}
			else{
				ne[c]=tr[ne[t]][i];
				if(id[ne[c]]!=0)las[c]=ne[c];
				else las[c]=las[ne[c]];
				q.push(c);
			}
		}
	}
}
void qry(string s){
	int p=0;
	for(int i=0;s[i];i++){
		if(s[i]=='|'){
			p=0;
			continue;
		}
		int t=s[i]-'a';
		p=tr[p][t];
		if(id[p]!=0)res[id[p]]++;
		int tmp=las[p];
		while(tmp){
			if(id[tmp]!=0)res[id[tmp]]++;
			tmp=las[tmp];
		}
	}
}
signed main(){
	cin>>n;
	for(int i=1;i<=n;i++){
		cin>>s;
		t+=s;
		t+="|";//为了防止跨单词统计,加上分隔字符
		ins(i);
	}
	build();
	qry(t);
	for(int i=1;i<=n;i++){
		cout<<res[mp[i]]<<'\n';
	}
	return 0;
}

这个写法在模板 \(3\) 会超时一个点,所以如果想要更稳定建议写拓扑排序,当然这个代码更短,也是可以的。

posted @ 2024-07-29 00:53  zxh923  阅读(4)  评论(0编辑  收藏  举报