关于AC自动机的一些理解 || Luogu3121 & 4824 Censoring - 哈希 - AC自动机

题目链接:https://www.luogu.com.cn/problem/P3121(4824)

题解:
4824 是 Censoring S,只需要对单模式串进行操作,3121 需要对多模式串

4824
开一个前缀hash数组,每次扫到当前点就判一下 \([i-k+1,i]\) 是否能与模式串的 hash 值相等,如果相等就删除,因为记录的是前缀所以维护很方便

3121
首先先把多模式串的 AC自动机 建出来

一些关于 AC自动机的理解:
AC自动机是基于字典树,先建立出多个模式串的trie
关键一点是 \(fail[]\) 指针,\(fail[p]\) 代表 \(p\) 这个点对应的(模式串前缀)所能匹配的其它模式串的最大后缀
有多个模式串:i she his he hers
image
如 9号结点 代表的是一个前缀 'she' ,其一个后缀能匹配上 'hers' 的一个前缀 'he',而且这个前缀是最长的,因此 \(fail[]\) 跳到这里
关于如何构建fail:需要用到父亲的fail,例如 求6号结点的fail,5号结点的fail指向了10号('i'),10号有没有 '-s' 这样的边?发现没有(如果有的话就直接将 \(tr[10][s]\) 作为6号的 fail),因此还需要跳fail,变成0号结点,发现\(tr[0][s]\) 存在 =7, 因此 fail[6]=7
但是每次匹配的时候都需要暴力跳 fail,能不能压缩一下路径呢?这就是 trie 图
例如 6号结点,his,和之前一样的思路,先跳到 10 号,如果我能由 10号直接跳到7号,就可以了
image
(注意 \(10 \rightarrow 7\) 的黑边 's')
这就是 trie图,注意引入了一些原来 trie 中不存在的结点
构建 trie图的 \(build()\)

void build(){
	queue<int>Q;
	Q.push(0);
	while(!Q.empty()){
		int u = Q.front();Q.pop();
		for(int i=0;i<26;i++)
			if(tr[u][i]){	// 如果 tr[u][i] 这个模式串前缀确实存在 
				// 这里相当于是路径压缩,因为 tr[fail[u]][i] 相当于已经是构建好的 trie图了,因此直接连就行了 
				fail[tr[u][i]] = u ? tr[fail[u]][i] : 0;
				/*
		lst[] 代表这个点为后缀的能匹配的下一个字符串位置
		如 当前点代表的是 abcd,而模式串是 abcd bcd cd那么 lst[abcd] -> bcd, lst[bcd] -> cd,这两个是为了能够计数才没有路径压缩
		lst[cd] -> lst[d],这里是路径压缩因为 d 这个后缀没有用处 
		*/
				if(val[fail[tr[u][i]]])lst[tr[u][i]] = fail[tr[u][i]];
				else lst[tr[u][i]] = lst[fail[tr[u][i]]];

				Q.push(tr[u][i]);
			}else tr[u][i] = tr[fail[u]][i];
		// 构建 trie图,tr[u][i] 不存在 就将其链接到 fail[u] 看看 fail[u] 有没有 i 这个儿子 
	}
}

回到这个题:
开一个栈,记一下每次在字典树跳对应的结点。如果当前匹配上了模式串,就不断pop栈(用数组实现栈更方便)回到删除之后的 trie 上的位置,继续跳 trie 就行了

4824:

// by SkyRainWind
#include <bits/stdc++.h>
#define mpr make_pair
#define debug() cerr<<"Yoshino\n"
#define pii pair<int,int>

using namespace std;

typedef long long ll;
typedef long long LL;

const int inf = 1e9, INF = 0x3f3f3f3f, p=19260817, mod = 998244353, maxn = 1e6+5;

ll bs[maxn];
ll hs, hs1[maxn];
char s[maxn], t[maxn];

ll geth(int l,int r){return (hs1[r] - hs1[l-1]*bs[r-l+1]%mod + mod) % mod;}

signed main(){
	scanf("%s",s+1);
	scanf("%s",t+1);
	bs[0]=1;
	int n = strlen(s + 1);
	for(int i=1;i<=n;i++)bs[i]=1ll*bs[i-1]*p%mod;
	int m=strlen(t+1);for(int i=1;i<=m;i++)hs=(hs*p%mod+t[i])%mod;
	
	int times = 0;
	int j = 0;
	string ans;ans.resize(n + 5);
	for(int i=1;i<=n;i++){
		++ j;
		hs1[j] = (hs1[j-1]*p%mod + s[i]) % mod;
		ans[j] = s[i];
		if(j>=m && geth(j-m+1, j) == hs){
			++ times;
			j -= m;
		}
	}
	int nn = n - times*m;
	for(int i=1;i<=nn;i++)cout<<ans[i];

	return 0;
}

3121:

// by SkyRainWind
#include <bits/stdc++.h>
#define mpr make_pair
#define debug() cerr<<"Yoshino\n"
#define pii pair<int,int>

using namespace std;

typedef long long ll;
typedef long long LL;

const int inf = 1e9, INF = 0x3f3f3f3f, maxn = 5e5+5;

int n, m;
char t[maxn], s[maxn];

struct AC{
	int lst[maxn];
	int tr[maxn][27],cnt;
	int val[maxn];	
	int fail[maxn];
	
	AC(){cnt=0;memset(val,0,sizeof val);}
	
	void insert(char *s){
		int p=0;
		int ns = strlen(s + 1);
		for(int i=1;i<=ns;i++){
			int k = s[i] - 'a';
			if(!tr[p][k])tr[p][k] = ++ cnt;
			p = tr[p][k];
		}
		val[p] = ns;
	}
	
	void build(){
		queue<int>Q;
		Q.push(0);
		while(!Q.empty()){
			int u = Q.front();Q.pop();
			for(int i=0;i<26;i++)
				if(tr[u][i]){	// 如果 tr[u][i] 这个模式串前缀确实存在 
					// 这里相当于是路径压缩,因为 tr[fail[u]][i] 相当于已经是构建好的 trie图了,因此直接连就行了 
					fail[tr[u][i]] = u ? tr[fail[u]][i] : 0;
		/*
		lst[] 代表这个点为后缀的能匹配的下一个字符串位置
		如 当前点代表的是 abcd,而模式串是 abcd bcd cd那么 lst[abcd] -> bcd, lst[bcd] -> cd,这两个是为了能够计数才没有路径压缩
		lst[cd] -> lst[d],这里是路径压缩因为 d 这个后缀没有用处 
		*/
					if(val[fail[tr[u][i]]])lst[tr[u][i]] = fail[tr[u][i]];
					else lst[tr[u][i]] = lst[fail[tr[u][i]]];
					
					Q.push(tr[u][i]);
				}else tr[u][i] = tr[fail[u]][i];
				// 构建 trie图,tr[u][i] 不存在 就将其链接到 fail[u] 看看 fail[u] 有没有 i 这个儿子 
		}
	}
	
	int stk[maxn], tp = 0;
	char stkk[maxn];
	void query(char *t){
		int p=0, res=0;
		int nt = strlen(t + 1);
		for(int i=1;i<=nt;i++){
			p = tr[p][t[i] - 'a'];
			stk[++ tp] = p;
			stkk[tp] = t[i];
			if(val[p]){
				tp -= val[p];
				p = stk[tp];
				continue;
			}
		}
		for(int i=1;i<=tp;i++)putchar(stkk[i]);puts("");
	}
}ac;

signed main(){
	scanf("%s",t + 1);
	n = strlen(t+1);
	scanf("%d",&m);
	for(int i=1;i<=m;i++){
		scanf("%s",s + 1);
		ac.insert(s);
	}
	ac.build();
	
	ac.query(t);

	return 0;
}
posted @ 2023-01-25 12:08  SkyRainWind  阅读(41)  评论(0编辑  收藏  举报