AC自动机

 

题目描述

给定n个模式串和1个文本串,求有多少个模式串在文本串里出现过。

输入输出格式

输入格式:

 

第一行一个n,表示模式串个数;

下面n行每行一个模式串;

下面一行一个文本串。

 

输出格式:

 

一个数表示答案

 

输入输出样例

输入样例#1:
2
a
aa
aa
输出样例#1:
2
模板,ac自动机
last版本
#include<cstdio>
#include<queue>
#include<cstring>
using namespace std;
const int  maxn = 1000001;
int n;
char s[maxn];

int ans=0;
struct Aho_Corasick_automato{
    int sz;
    int ch[maxn][26];
    int val[maxn];
    int last[maxn];
    int fail[maxn];
    int num;
    void init() {
        memset(ch[0],0,sizeof(ch[0]));
        sz=1;
    }
    void insert(char *s) {
        int len=strlen(s);
        int u=0;
        for(int i=0; i<len; ++i) {
            int v=(s[i]-'a');
            if(!ch[u][v]) {
                memset(ch[sz],0,sizeof(ch[sz]));
                val[sz]=0;
                ch[u][v]=sz++;
            }
            u=ch[u][v];
        }
        val[u]++;
    }
    void get_fail() {
        fail[0]=0;
        queue<int>que;
        for(int i=0; i<26; i++) {
            int u=ch[0][i];
            if(u) {
                fail[u]=0;
                que.push(u);
            }
        }
        while(!que.empty()) {
            int u=que.front(); 
            que.pop();
            for(int i=0; i<26; i++) {
                int v=ch[u][i];
                if(!v) {
                    ch[u][i]=ch[fail[u]][i];
                    continue;
                }
                que.push(v);
                int k=fail[u];
                fail[v]=ch[k][i];
                last[v]=val[fail[v]] ? fail[v] : last[fail[v]];
            }
        }
    }
    void work(int x) {
        if(val[x]) {
            ans+=val[x];
            val[x]=0;
            work(last[x]);
        }
    }
    void find(char *s) {
        int len=strlen(s);
        int u=0;
        for(int i=0; i<len; i++) {
            int v=(s[i]-'a');
            if(!ch[u][v])u=fail[u];
            while(u&&!ch[u][v])
                u=fail[u];
            u=ch[u][v];
            if(val[u])
                work(u);
            else if(last[u])
                work(last[u]);
        }   
    }       
} ac;

int main() {

    scanf("%d",&n);
    ac.init();
    for(int i=1; i<=n; i++) {
        scanf("%s",s);
        ac.insert(s);
    }
    ac.get_fail();
    scanf("%s",s);
    ac.find(s);
    printf("%d\n",ans);
    return 0;
}

 

posted @ 2017-07-30 15:08  zzzzx  阅读(183)  评论(0编辑  收藏  举报