P5357 【模板】AC自动机(二次加强版)
【模板】AC自动机(二次加强版)
题目描述
给你一个文本串 SS 和 nn 个模式串 T_{1..n}T
1..n
,请你分别求出每个模式串 T_iT
i
在 SS 中出现的次数。
输入格式
第一行包含一个正整数 nn 表示模式串的个数。
接下来 nn 行,第 ii 行包含一个由小写英文字母构成的字符串 T_iT
i
。
最后一行包含一个由小写英文字母构成的字符串 SS。
数据不保证任意两个模式串不相同。
输出格式
输出包含 nn 行,其中第 ii 行包含一个非负整数表示 T_iT
i
在 SS 中出现的次数。
工口发生:拓扑排序为了简便没把没权值的入队, 不能减入度了显然有问题
\(Solution\)
对于匹配串的每一个字符\(s[i]\),我们希望得到其所有以\(s[i]\)为真后缀的模式串的计数。
而以\(s[i]\)为真后缀的所有串一定在其 \(fail\)树上
于是我们通过 \(fail\) 指针往前, 遇到结束节点累计答案即可
题目不保证字符串不两两重复,所以每个节点开个 \(vector\) 来存原下标即可
然而这样会 \(TLE\)
原因很简单, 对于匹配串的每一个字符, 我们会通过 \(fail\) 指针一直跳到根节点来累计答案
显然,考虑最坏情况(比如 \(aaaaaaaaa\) ), 复杂度会达到 \(O(Len_{patern} * Dep_{trie})\) , 显然会炸
想办法优化, 显然,对于一个节点, 他通过 \(fail\) 指针到根的路径是唯一的, 我们浪费的时间主要在重复地走这些边
那么能不能打个 \(tag\) ,把匹配串所有字符都处理完再一次性走过这些路线呢
\(fail\) 的路径是一颗树, 于是我们用拓扑来计算tag即可
Code
#include<iostream>
#include<cstdio>
#include<queue>
#include<cstring>
#include<algorithm>
#include<climits>
#define LL long long
#define REP(i, x, y) for(int i = (x);i <= (y);i++)
using namespace std;
int RD(){
int out = 0,flag = 1;char c = getchar();
while(c < '0' || c >'9'){if(c == '-')flag = -1;c = getchar();}
while(c >= '0' && c <= '9'){out = out * 10 + c - '0';c = getchar();}
return flag * out;
}
const int MAX_Tot = 200010;
const int MAX_N = 2000010;
int ans[MAX_Tot];
int inq[MAX_Tot];//统计贡献fail树入度
queue<int>Que;
struct Aho{
struct state{
int nxt[26];
int fail, cnt, tag;
vector<int>mem;//待修改
}stateTable[MAX_Tot];
int size;
queue<int>Q;
void init(){
while(!Q.empty())Q.pop();
REP(i, 0, MAX_Tot - 1){
memset(stateTable[i].nxt, 0, sizeof(stateTable[i].nxt));
stateTable[i].fail = stateTable[i].cnt = stateTable[i].tag = 0;
stateTable[i].mem.clear();
}
size = 0;//root = 0
}
void insert(char *s, int Index){
int n = strlen(s), now = 0;
REP(i, 0, n - 1){
char c = s[i];
if(!stateTable[now].nxt[c - 'a'])stateTable[now].nxt[c - 'a'] = ++size;
now = stateTable[now].nxt[c - 'a'];
}
stateTable[now].cnt++;
stateTable[now].mem.push_back(Index);
}
void build_fail(){
stateTable[0].fail = -1;
Q.push(0);
while(!Q.empty()){
int u = Q.front();Q.pop();
REP(i, 0, 25){
int id = stateTable[u].nxt[i];
if(id){
if(u == 0)stateTable[id].fail = 0, inq[0]++;
else{
int v = stateTable[u].fail;
while(v != -1){
if(stateTable[v].nxt[i]){
stateTable[id].fail = stateTable[v].nxt[i], inq[stateTable[v].nxt[i]]++;
break;
}
v = stateTable[v].fail;
}
if(v == -1)stateTable[id].fail = 0, inq[0]++;
}
Q.push(id);
}
}
}
}
void match(char *s){
int n = strlen(s), now = 0;
REP(i, 0, n - 1){
char c = s[i];
if(stateTable[now].nxt[c - 'a'])now = stateTable[now].nxt[c - 'a'];
else{
int p = stateTable[now].fail;
while(p != -1 && stateTable[p].nxt[c - 'a'] == 0)p = stateTable[p].fail;
if(p == -1)now = 0;
else now = stateTable[p].nxt[c - 'a'];
}
stateTable[now].tag++;
}
}
}aho;
int num;
char S[MAX_N];
void init(){
num = RD();
aho.init();
REP(i, 1, num){
cin>>S;
aho.insert(S, i);
}
aho.build_fail();
}
void work(){
cin>>S;
aho.match(S);
REP(i, 1, aho.size)if(inq[i] == 0)Que.push(i);
while(!Que.empty()){
int u = Que.front();Que.pop();
int len = aho.stateTable[u].mem.size();
for(int i = 0; i < len; i++){
ans[aho.stateTable[u].mem[i]] += aho.stateTable[u].tag;
}
aho.stateTable[aho.stateTable[u].fail].tag += aho.stateTable[u].tag;
inq[aho.stateTable[u].fail]--;
if(inq[aho.stateTable[u].fail] == 0)Que.push(aho.stateTable[u].fail);
}
REP(i, 1, num){printf("%d\n", ans[i]);}
}
int main(){
init();
work();
return 0;
}