luogu P3808 【模板】AC自动机(简单版)

题目背景

这是一道简单的AC自动机模板题。

用于检测正确性以及算法常数。

为了防止卡OJ,在保证正确的基础上只有两组数据,请不要恶意提交。

管理员提示:本题数据内有重复的单词,且重复单词应该计算多次,请各位注意

题目描述

给定n个模式串和1个文本串,求有多少个模式串在文本串里出现过。

输入输出格式

输入格式:

 

第一行一个n,表示模式串个数;

下面n行每行一个模式串;

下面一行一个文本串。

 

输出格式:

 

一个数表示答案

 

输入输出样例

输入样例#1: 复制
2
a
aa
aa
输出样例#1: 复制
2

说明

subtask1[50pts]:∑length(模式串)<=10^6,length(文本串)<=10^6,n=1;

subtask2[50pts]:∑length(模式串)<=10^6,length(文本串)<=10^6;

终于A了,woc爽

TLE+RE的原因竟然是把数组开成char

#include<cstdio>
#include<queue>
#include<cstring>
using namespace std;
const int  maxn = 1000004;
int n;
char s[maxn];
queue<int>que;
int ans=0;
struct Aho_Corasick_automaton {
    int sz;
    char ch[maxn][26];
    int val[maxn];
    int last[maxn];
    int fail[maxn];
    int num;
    void init() {
        memset(ch[0],0,sizeof(ch[0]));
        sz=0;
    }
    void insert(char *s) {
        int len=strlen(s);
        int u=0;
        for(int i=0; i<len; ++i) {
            int v=s[i]-'a';
            if(!ch[u][v]) {
                val[sz+1]=0;
                ch[u][v]=++sz;
            }
            u=ch[u][v];
       }
        //printf("%d\n",u);
        val[u]++;
        //printf("%d\n",val[1]);
    }
    void get_fail() {
        fail[0]=0;
        que.push(0);
        while(!que.empty()) {
            int u=que.front();
            que.pop();
            for(int i=0;i<26;i++) {
                int v=ch[u][i];
                if(!v) {
                    ch[u][i]=ch[fail[u]][i];
                    continue;
                }
                que.push(v);
                fail[v]=u ? ch[fail[u]][i]:0;
                last[v]=val[fail[v]] ? fail[v] : last[fail[v]];
            }
        }
    }
    void find(char *s) {
        int len=strlen(s);
        int u=0;
        for(int i=0; i<len; i++) {
            int c=s[i]-'a';
            u=ch[u][c];
        //    printf("%d\n",val[u]);
            if(val[u])ans+=val[u],val[u]=0;
        //    printf("%d      ***\n",ans);
            int v=u;
            while(last[v]) {
                v=last[v];
                if(val[v])ans+=val[v],val[v]=0;
            }
        }
    }
} ac;

int main() {
    scanf("%d",&n);
    ac.init();
    for(int i=1; i<=n; i++) {
        scanf("%s",s);
        ac.insert(s);
    }
    ac.get_fail();
    scanf("%s",s);
    ac.find(s);
    //for(int i=1;i<=ac.sz;++i)printf("%d ",ac.val[i]);
    printf("%d\n",ans);
    return 0;
}

 

posted @ 2017-11-25 20:39  zzzzx  阅读(200)  评论(0编辑  收藏  举报