SCL-- AC 自动机

2015-05-21 23:53:19

总结:来小结一下 AC 自动机的模板以及算法细节。

  

 // AC automaton template over the lowercase alphabet 'a'..'z'.
// Usage: init(); insert() every pattern; getfail(); then find().
const int RA = 26;        // alphabet size
const int MAXN = 500010;  // trie capacity, in nodes (root included)

// Text buffer scanned by the no-argument find(); the program using this
// template must define it, e.g. `char S[1000010];`.
extern char S[];

struct AC_auto{
    // next: trie edges, completed into a goto function by getfail() (0 = root / no edge)
    // f:    failure link of each node
    // cnt:  number of inserted patterns ending at the node; find() sets it to -1
    //       once the node has been visited, so no pattern is counted twice
    // last: nearest node on the failure chain where a pattern ends (0 = none)
    int next[MAXN][RA],f[MAXN],cnt[MAXN],last[MAXN];
    int root,tot;   // root is node 0; tot is the index of the last allocated node

    // Reset node p: no edges, no links, no pattern ending here.
    void new_node(int p){
        for(int i = 0; i < RA; ++i) next[p][i] = 0;
        f[p] = cnt[p] = last[p] = 0;
    }

    // Clear the automaton so that only the root remains.
    void init(){
        root = tot = 0;
        new_node(root);
    }

    // Insert one pattern of lowercase letters into the trie (a plain trie insert).
    // Does not check the MAXN capacity.
    void insert(const char *str){
        int len = std::strlen(str);
        int p = root;
        for(int i = 0; i < len; ++i){
            int id = str[i] - 'a';
            if(next[p][id] == 0){   // no child for this letter yet: allocate one
                next[p][id] = ++tot;
                new_node(tot);
            }
            p = next[p][id];
        }
        cnt[p]++;                  // duplicate patterns are counted separately
    }

    // Build failure and output links by BFS and complete every missing edge,
    // so matching never has to follow failure links explicitly.
    // Call after all insert() calls and before find().
    void getfail(){
        std::queue<int> Q;
        f[root] = root;
        for(int i = 0; i < RA; ++i){   // depth-1 nodes fail to the root (f = 0)
            int u = next[root][i];
            if(u) Q.push(u);
        }
        while(!Q.empty()){
            int x = Q.front(); Q.pop();
            for(int i = 0; i < RA; ++i){
                int u = next[x][i];
                if(u == 0){
                    // Missing edge: precompute where the failure chain would lead.
                    next[x][i] = next[f[x]][i];
                    continue;
                }
                Q.push(u);
                // f[x] is shallower than x, so BFS has already completed its edges;
                // next[f[x]][i] is therefore u's failure target (the root if none).
                f[u] = next[f[x]][i];
                // Output link: skip failure-chain nodes where no pattern ends.
                last[u] = cnt[f[u]] ? f[u] : last[f[u]];
            }
        }
    }

    // Return how many inserted patterns (duplicates included) occur in str.
    // Each pattern is counted at most once, also across repeated calls.
    // Characters outside 'a'..'z' cannot be part of any pattern, so they restart
    // matching at the root instead of indexing out of bounds.
    int find(const char *str){
        int len = std::strlen(str);
        int p = root,res = 0;
        for(int i = 0; i < len; ++i){
            int id = str[i] - 'a';
            if(id < 0 || id >= RA){
                p = root;
                continue;
            }
            p = next[p][id];
            // Walk the output-link chain. A node marked -1 was visited before, and
            // so was its whole chain, so the walk stops there; each node's chain
            // is therefore walked at most once.
            for(int cur = p; cur != root && cnt[cur] != -1; cur = last[cur]){
                res += cnt[cur];
                cnt[cur] = -1;
            }
        }
        return res;
    }

    // Scan the global text buffer S (the template's original interface).
    int find(){ return find(S); }
}ac;

 

相关的几个模板题:

(1)hdu 2222,纯模板辣。

  如果不使用 last 数组(后缀链接)优化,运行时间约为 700+ ms;加上 last 优化后可降到约 200+ ms。

#include <cstdio>
#include <cstring>
#include <cstdlib>
#include <cmath>
#include <vector>
#include <map>
#include <set>
#include <stack>
#include <queue>
#include <string>
#include <iostream>
#include <algorithm>
using namespace std;

// Helper macros carried over from the author's contest template
// (none of them is used by this program).
#define getmid(l,r) ((l) + ((r) - (l)) / 2)
#define MP(a,b) make_pair(a,b)
#define PB(a) push_back(a)

typedef std::pair<int,int> pii;
typedef long long ll;
const int INF = 0x3fffffff;   // == (1 << 30) - 1
const double eps = 1e-8;

const int RA = 26;            // alphabet size: 'a'..'z'
const int MAXN = 500010;      // trie capacity, in nodes

char S[1000010];              // buffer for each pattern, then for the text
int T, N;                     // number of test cases; patterns per test case

struct AC_auto{
    // next: trie edges, completed into a goto function by getfail() (0 = root / no edge)
    // f:    failure link of each node
    // cnt:  number of inserted patterns ending at the node; find() sets it to -1
    //       once the node has been visited, so no pattern is counted twice
    // last: nearest node on the failure chain where a pattern ends (0 = none)
    int next[MAXN][RA],f[MAXN],cnt[MAXN],last[MAXN];
    int root,tot;   // root is node 0; tot is the index of the last allocated node

    // Reset node p: no edges, no links, no pattern ending here.
    void new_node(int p){
        for(int i = 0; i < RA; ++i) next[p][i] = 0;
        f[p] = cnt[p] = last[p] = 0;
    }

    // Clear the automaton so that only the root remains.
    void init(){
        root = tot = 0;
        new_node(root);
    }

    // Insert one pattern of lowercase letters into the trie (a plain trie insert).
    // Does not check the MAXN capacity.
    void insert(const char *str){
        int len = std::strlen(str);
        int p = root;
        for(int i = 0; i < len; ++i){
            int id = str[i] - 'a';
            if(next[p][id] == 0){   // no child for this letter yet: allocate one
                next[p][id] = ++tot;
                new_node(tot);
            }
            p = next[p][id];
        }
        cnt[p]++;                  // duplicate patterns are counted separately
    }

    // Build failure and output links by BFS and complete every missing edge,
    // so matching never has to follow failure links explicitly.
    // Call after all insert() calls and before find().
    void getfail(){
        std::queue<int> Q;
        f[root] = root;
        for(int i = 0; i < RA; ++i){   // depth-1 nodes fail to the root (f = 0)
            int u = next[root][i];
            if(u) Q.push(u);
        }
        while(!Q.empty()){
            int x = Q.front(); Q.pop();
            for(int i = 0; i < RA; ++i){
                int u = next[x][i];
                if(u == 0){
                    // Missing edge: precompute where the failure chain would lead.
                    next[x][i] = next[f[x]][i];
                    continue;
                }
                Q.push(u);
                // f[x] is shallower than x, so BFS has already completed its edges;
                // next[f[x]][i] is therefore u's failure target (the root if none).
                f[u] = next[f[x]][i];
                // Output link: skip failure-chain nodes where no pattern ends.
                last[u] = cnt[f[u]] ? f[u] : last[f[u]];
            }
        }
    }

    // Return how many inserted patterns (duplicates included) occur in str.
    // Each pattern is counted at most once, also across repeated calls.
    // Characters outside 'a'..'z' cannot be part of any pattern, so they restart
    // matching at the root instead of indexing out of bounds.
    int find(const char *str){
        int len = std::strlen(str);
        int p = root,res = 0;
        for(int i = 0; i < len; ++i){
            int id = str[i] - 'a';
            if(id < 0 || id >= RA){
                p = root;
                continue;
            }
            p = next[p][id];
            // Walk the output-link chain. A node marked -1 was visited before, and
            // so was its whole chain, so the walk stops there; each node's chain
            // is therefore walked at most once.
            for(int cur = p; cur != root && cnt[cur] != -1; cur = last[cur]){
                res += cnt[cur];
                cnt[cur] = -1;
            }
        }
        return res;
    }

    // Scan the global text buffer S (the original interface used by main).
    int find(){ return find(S); }
}ac;

// hdu 2222: for each test case, read N patterns and one text, and print how
// many of the patterns occur in the text.
int main(){
    scanf("%d",&T);
    while(T--){
        ac.init();
        scanf("%d",&N);
        for(int k = 0; k < N; ++k){   // read and insert the N patterns
            scanf("%s",S);
            ac.insert(S);
        }
        ac.getfail();                 // failure links must exist before matching
        scanf("%s",S);                // the text reuses the same buffer
        int matched = ac.find();
        printf("%d\n",matched);
    }
    return 0;
}
View Code

 

(2)hdu 2896,依旧是模板题,只是要在每个单词的结尾节点记录该模式串的编号。

#include <cstdio>
#include <cstring>
#include <cstdlib>
#include <cmath>
#include <vector>
#include <map>
#include <set>
#include <stack>
#include <queue>
#include <string>
#include <iostream>
#include <algorithm>
using namespace std;

// Helper macros carried over from the author's contest template
// (none of them is used by this program).
#define getmid(l,r) ((l) + ((r) - (l)) / 2)
#define MP(a,b) make_pair(a,b)
#define PB(a) push_back(a)

typedef std::pair<int,int> pii;
typedef long long ll;
const int INF = 0x3fffffff;   // == (1 << 30) - 1
const double eps = 1e-8;

const int RA = 128;           // one edge slot per 7-bit character code
const int MAXN = 100010;      // trie capacity, in nodes

int N, M;                     // number of patterns; number of texts (websites)
char s[10010];                // input buffer for a pattern or a website's text
int vis[510];                 // vis[k] != 0 when pattern k occurs in the current text

struct AC_auto{
    // next: trie edges indexed by byte value, completed into a goto function by
    //       getfail() (0 = root / no edge)
    // f:    failure link of each node
    // num:  1-based number of the pattern ending at the node (0 = none)
    // last: nearest node on the failure chain where a pattern ends (0 = none)
    int next[MAXN][RA],f[MAXN],num[MAXN],last[MAXN];
    int root,tot;   // root is node 0; tot is the index of the last allocated node

    // Reset node p: no edges, no links, no pattern ending here.
    void new_node(int p){
        for(int i = 0; i < RA; ++i) next[p][i] = 0;
        f[p] = num[p] = last[p] = 0;
    }

    // Clear the automaton so that only the root remains.
    void init(){
        root = tot = 0;
        new_node(root);
    }

    // Insert pattern number k (k >= 1); an identical later pattern overwrites
    // the number. Bytes are read as unsigned char so a byte >= 0x80 never yields
    // a negative index. A pattern containing a byte >= RA can never be matched
    // (find() restarts on such bytes), so it is not marked as a pattern.
    void insert(int k,const char *str){
        int len = std::strlen(str);
        int p = root;
        for(int i = 0; i < len; ++i){
            int id = static_cast<unsigned char>(str[i]);
            if(id >= RA) return;   // unmatchable pattern: leave it unmarked
            if(next[p][id] == 0){
                next[p][id] = ++tot;
                new_node(tot);
            }
            p = next[p][id];
        }
        num[p] = k;
    }

    // Build failure and output links by BFS and complete every missing edge.
    // Call after all insert() calls and before find().
    void getfail(){
        std::queue<int> Q;
        f[root] = root;
        for(int i = 0; i < RA; ++i){   // depth-1 nodes fail to the root (f = 0)
            int u = next[root][i];
            if(u) Q.push(u);
        }
        while(!Q.empty()){
            int x = Q.front(); Q.pop();
            for(int i = 0; i < RA; ++i){
                int u = next[x][i];
                if(u == 0){
                    // Missing edge: precompute where the failure chain would lead.
                    next[x][i] = next[f[x]][i];
                    continue;
                }
                Q.push(u);
                // f[x] is shallower than x, so its edges are already completed.
                f[u] = next[f[x]][i];
                last[u] = num[f[u]] ? f[u] : last[f[u]];
            }
        }
    }

    // Set vis[k] = 1 for every pattern k occurring in str (vis is not cleared
    // here). Returns 1 if at least one pattern occurs, otherwise 0.
    int find(const char *str){
        int len = std::strlen(str);
        int p = root,res = 0;
        for(int i = 0; i < len; ++i){
            int id = static_cast<unsigned char>(str[i]);
            if(id >= RA){   // byte outside the alphabet: no pattern spans it
                p = root;
                continue;
            }
            p = next[p][id];
            for(int cur = p; cur != root; cur = last[cur]){
                if(num[cur]){
                    res = 1;
                    vis[num[cur]] = 1;
                }
            }
        }
        return res;
    }
}ac;

// hdu 2896: read N numbered patterns, then M texts; for each text that contains
// at least one pattern, print its number and the matching pattern numbers,
// and finally the number of such texts.
int main(){
    while(scanf("%d",&N) != EOF){
        ac.init();
        for(int id = 1; id <= N; ++id){
            scanf("%s",s);
            ac.insert(id,s);
        }
        ac.getfail();
        scanf("%d",&M);
        int infected = 0;
        for(int web = 1; web <= M; ++web){
            scanf("%s",s);
            memset(vis,0,sizeof(vis));   // find() only sets flags, never clears them
            if(!ac.find(s)) continue;
            ++infected;
            printf("web %d:",web);
            for(int k = 1; k <= N; ++k)
                if(vis[k]) printf(" %d",k);
            puts("");
        }
        printf("total: %d\n",infected);
    }
    return 0;
}
View Code

 

(3)hdu 3065,模板题,计算每个模式串在文本串中的出现次数。注意文本串可能含有空格,需要整行读入(原文用 gets;gets 已在 C11/C++14 中被移除,在现代编译器下应改用 fgets)。

#include <cstdio>
#include <cstring>
#include <cstdlib>
#include <cmath>
#include <vector>
#include <map>
#include <set>
#include <stack>
#include <queue>
#include <string>
#include <iostream>
#include <algorithm>
using namespace std;

// Helper macros carried over from the author's contest template
// (none of them is used by this program).
#define getmid(l,r) ((l) + ((r) - (l)) / 2)
#define MP(a,b) make_pair(a,b)
#define PB(a) push_back(a)

typedef long long ll;
typedef std::pair<int,int> pii;
const double eps = 1e-8;
const int INF = (1 << 30) - 1;
const int MAXN = 2000010;   // size of the text buffer s
const int RA = 26;          // alphabet 'A'..'Z'

int N;                      // number of patterns
char s[MAXN];               // the text to scan
char tmp[1010][60];         // patterns, 1-based
// Occurrence counter per pattern (1-based; index 0 collects non-pattern nodes).
// Must be int: in a text this long a pattern can occur far more than 127 times,
// which overflowed the previous char declaration.
int num[1010];

struct AC_auto{
    // next: trie edges over 'A'..'Z', completed into a goto function by getfail()
    // f:    failure links
    // cnt:  1-based index of the pattern ending at the node (0 = none)
    // last: nearest node on the failure chain where a pattern ends (0 = none)
    int next[50010][RA],f[50010],cnt[50010],last[50010];
    int root,tot;

    // Reset node p to an empty node.
    void new_node(int p){
        f[p] = cnt[p] = last[p] = 0;
        for(int c = 0; c < RA; ++c) next[p][c] = 0;
    }

    // Empty the automaton, leaving only the root (node 0).
    void init(){
        root = 0;
        tot = 0;
        new_node(root);
    }

    // Add pattern number k; expects uppercase letters only (no range check).
    void insert(int k,char *str){
        int node = root;
        for(const char *c = str; *c; ++c){
            int &child = next[node][*c - 'A'];
            if(!child){
                child = ++tot;
                new_node(tot);
            }
            node = child;
        }
        cnt[node] = k;
    }

    // BFS over the trie: set failure/output links and fill every missing edge
    // with the transition its failure node would take.
    void getfail(){
        std::queue<int> bfs;
        f[root] = root;
        for(int c = 0; c < RA; ++c)
            if(next[root][c]) bfs.push(next[root][c]);
        while(!bfs.empty()){
            int x = bfs.front();
            bfs.pop();
            for(int c = 0; c < RA; ++c){
                int &to = next[x][c];
                if(!to){
                    to = next[f[x]][c];
                    continue;
                }
                bfs.push(to);
                // f[x] was dequeued earlier (it is shallower), so its row is complete.
                int fail = next[f[x]][c];
                f[to] = fail;
                last[to] = cnt[fail] ? fail : last[fail];
            }
        }
    }

    // Add to num[k] the number of occurrences of pattern k in str.
    // Any byte outside 'A'..'Z' breaks matching and restarts at the root.
    void find(char *str){
        int node = root;
        for(const char *c = str; *c; ++c){
            int id = *c - 'A';
            if(!(0 <= id && id < 26)){
                node = root;
                continue;
            }
            node = next[node][id];
            for(int hit = node; hit != root; hit = last[hit])
                num[cnt[hit]]++;
        }
    }
}ac;

// hdu 3065: read N patterns and one text line, then print every pattern that
// occurs in the text together with its number of occurrences.
int main(){
    // Stop on EOF and also on malformed input (scanf returning 0 would loop forever).
    while(scanf("%d",&N) == 1){
        memset(num,0,sizeof(num));
        ac.init();
        for(int i = 1; i <= N; ++i){
            scanf("%s",tmp[i]);
            ac.insert(i,tmp[i]);
        }
        ac.getfail();
        getchar();   // consume the newline left after the last pattern
        // The text may contain spaces, so read a whole line. gets() was removed
        // in C11/C++14 and is unbounded; fgets() is limited to the buffer size.
        if(fgets(s,sizeof(s),stdin) == NULL) s[0] = '\0';
        s[strcspn(s,"\r\n")] = '\0';   // drop the line terminator fgets keeps
        ac.find(s);
        for(int i = 1; i <= N; ++i) if(num[i]){
            printf("%s: %d\n",tmp[i],num[i]);
        }
    }
    return 0;
}
View Code

 

posted @ 2015-05-22 00:22  Naturain  阅读(139)  评论(0编辑  收藏  举报