P5357 【模板】AC自动机(二次加强版)

【模板】AC自动机(二次加强版)

题目描述
给你一个文本串 SS 和 nn 个模式串 T_{1..n}T
1..n

,请你分别求出每个模式串 T_iT
i

在 SS 中出现的次数。

输入格式
第一行包含一个正整数 nn 表示模式串的个数。

接下来 nn 行,第 ii 行包含一个由小写英文字母构成的字符串 T_iT
i

最后一行包含一个由小写英文字母构成的字符串 SS。

数据不保证任意两个模式串不相同。

输出格式
输出包含 nn 行,其中第 ii 行包含一个非负整数表示 T_iT
i

在 SS 中出现的次数。


工口发生:拓扑排序为了简便没把没权值的入队, 不能减入度了显然有问题

\(Solution\)

对于匹配串的每一个字符\(s[i]\),我们希望得到其所有以\(s[i]\)为真后缀的模式串的计数。
而以\(s[i]\)为真后缀的所有串一定在其 \(fail\)树上
于是我们通过 \(fail\) 指针往前, 遇到结束节点累计答案即可
题目不保证字符串不两两重复,所以每个节点开个 \(vector\) 来存原下标即可

然而这样会 \(TLE\)

原因很简单, 对于匹配串的每一个字符, 我们会通过 \(fail\) 指针一直跳到根节点来累计答案
显然,考虑最坏情况(比如 \(aaaaaaaaa\) ), 复杂度会达到 \(O(Len_{patern} * Dep_{trie})\) , 显然会炸

想办法优化, 显然,对于一个节点, 他通过 \(fail\) 指针到根的路径是唯一的, 我们浪费的时间主要在重复地走这些边
那么能不能打个 \(tag\) ,把匹配串所有字符都处理完再一次性走过这些路线呢
\(fail\) 的路径是一颗树, 于是我们用拓扑来计算tag即可

Code

#include<iostream>
#include<cstdio>
#include<queue>
#include<cstring>
#include<algorithm>
#include<climits>
#define LL long long
#define REP(i, x, y) for(int i = (x);i <= (y);i++)
using namespace std;
int RD(){
    int out = 0,flag = 1;char c = getchar();
    while(c < '0' || c >'9'){if(c == '-')flag = -1;c = getchar();}
    while(c >= '0' && c <= '9'){out = out * 10 + c - '0';c = getchar();}
    return flag * out;
    }
const int MAX_Tot = 200010;
const int MAX_N = 2000010;
int ans[MAX_Tot];
int inq[MAX_Tot];//统计贡献fail树入度
queue<int>Que;
struct Aho{
	struct state{
		int nxt[26];
		int fail, cnt, tag;
		vector<int>mem;//待修改
		}stateTable[MAX_Tot];
	
	int size;
	queue<int>Q;
	
	void init(){
		while(!Q.empty())Q.pop();
		REP(i, 0, MAX_Tot - 1){
			memset(stateTable[i].nxt, 0, sizeof(stateTable[i].nxt));
			stateTable[i].fail = stateTable[i].cnt = stateTable[i].tag = 0;
			stateTable[i].mem.clear();
			}
		size = 0;//root = 0
		}
	
	void insert(char *s, int Index){
		int n = strlen(s), now = 0;
		REP(i, 0, n - 1){
			char c = s[i];
			if(!stateTable[now].nxt[c - 'a'])stateTable[now].nxt[c - 'a'] = ++size;
			now = stateTable[now].nxt[c - 'a'];
			}
		stateTable[now].cnt++;
		stateTable[now].mem.push_back(Index);
		}
		
	void build_fail(){
		stateTable[0].fail = -1;
		Q.push(0);
		while(!Q.empty()){
			int u = Q.front();Q.pop();
			REP(i, 0, 25){
				int id = stateTable[u].nxt[i];
				if(id){
					if(u == 0)stateTable[id].fail = 0, inq[0]++;
					else{
						int v = stateTable[u].fail;
						while(v != -1){
							if(stateTable[v].nxt[i]){
								stateTable[id].fail = stateTable[v].nxt[i], inq[stateTable[v].nxt[i]]++;
								break;
								}
							v = stateTable[v].fail;
							}
						if(v == -1)stateTable[id].fail = 0, inq[0]++;
						}
					Q.push(id);
					}
				}
			}
		}
		
	void match(char *s){
		int n = strlen(s), now = 0;
		REP(i, 0, n - 1){
			char c = s[i];
			if(stateTable[now].nxt[c - 'a'])now = stateTable[now].nxt[c - 'a'];
			else{
				int p = stateTable[now].fail;
				while(p != -1 && stateTable[p].nxt[c - 'a'] == 0)p = stateTable[p].fail;
				if(p == -1)now = 0;
				else now = stateTable[p].nxt[c - 'a'];
				}
			stateTable[now].tag++;
			}
		}
	}aho;
	
int num;
char S[MAX_N];
void init(){
	num = RD();
	aho.init();
	REP(i, 1, num){
		cin>>S;
		aho.insert(S, i);
		}
	aho.build_fail();
	}
void work(){
	cin>>S;
	aho.match(S);
	REP(i, 1, aho.size)if(inq[i] == 0)Que.push(i);
	while(!Que.empty()){
		int u = Que.front();Que.pop();
		int len = aho.stateTable[u].mem.size();
		for(int i = 0; i < len; i++){
			ans[aho.stateTable[u].mem[i]] += aho.stateTable[u].tag;
			}
		aho.stateTable[aho.stateTable[u].fail].tag += aho.stateTable[u].tag;
		inq[aho.stateTable[u].fail]--;
		if(inq[aho.stateTable[u].fail] == 0)Que.push(aho.stateTable[u].fail);
		}
	REP(i, 1, num){printf("%d\n", ans[i]);}
	}
int main(){
	init();
	work();
	return 0;
	}
posted @ 2021-02-03 20:10  Tony_Double_Sky  阅读(104)  评论(0编辑  收藏  举报