AC自动机

一种能让你AC的算法?
一种字符串匹配算法!
用来解决多字符串匹配问题的算法啦。
首先想一想KMP算法……(不知道KMP是什么?到时候我会写一篇专门的博客,敬请期待)
例如字符串aabaaaba与aba的匹配:
如果你看到了这个,说明图片挂了
在第二个位置失配,这个时候模式串可以向下跳一位:
如果你看到了这个,说明图片挂了
全部匹配成功,这个时候可以向下跳两位:
如果你看到了这个,说明图片挂了
那么像这样的过程可以看作是只有一个模式串的多串匹配。
这样可以对模式串aba建立一个图:
图挂了,自己脑补
其中点中的数字代表已经匹配到了哪一位,而边代表如果主串的下一个字符是黑色边上的字符那么就走黑色边,同时主串的当前字符变成当前字符的下一个字符,否则走红色回溯边。到达3号点的次数就是模式串aba在主串中出现的次数。
如:aabaaaba这个主串的匹配顺序是:第一个位置,匹配到了a,走黑色边;第二个位置,由于下一个应该出现的字符是b而主串出现的是a,那么走一次红色回溯边,然后走从0到1的黑色边;第三个位置,走1-2的黑色边;第四个位置,走2-3的黑色边,已经到达了3,答案+1;第五个位置应当走两边红色回溯边(因为走一遍过后应当出现的字符是b,而实际是a)……这可以被认为是一次KMP匹配过程,而建立红色边的过程就是KMP中建立next数组的过程。
可以看出,上面的过程中建立了一条链和一些红色边(回溯边),当模式串更多,这个时候就显然不能建立多个链,而是建立一棵树了(trie树)。例如,{he,she,her,shr}这4个单词(虽然最后面那个可能不是)建立的trie树就是:
自己脑补
其中打上绿色标记的是一个单词的终止节点。
那么红色边的定义就变成了:
fail指针:一个点x的fail指针指向的是离trie的根最远的点y,使得trie[root,y]=trie[z,x],其中z是x的祖先。
例如这个图的红色边就有:
脑补吧
(由于某些技术上的原因导致这张图没有前面的几个图好看了,请原谅)
那么就可以照搬上面的方法来解决这个问题了:走的通就走黑色边,否则走红色边,到达绿色节点这个点的匹配数量加1。这样就非常完美的解决了多字符串匹配问题~~~
fail指针的添加等具体细节请见代码。

代码(非指针版)

实际上代码还是非常短的,只是我写的很长而已(作为一名扩行者的提醒)……

struct ac_automaton
{
  int cnt,son[maxm+10][26],fail[maxm+10],ans;
  int num[maxm+10],q[maxm+10],head,tail,tot,b[maxn+10],w[maxm+10];
  //cnt数组记录节点的种类
  //son记录这个点的所有儿子
  //fail就是正文中的意思
  //ans记录的是有多少个串被匹配成功了
  //num记录的是一个点有多少个串
  //q是队列
  //tot记录trie中有多少个字串
  //b记录的是一个点是否被搜索到
  //w记录的是一个模式串的末尾位置

  int init()//初始化应该没有任何问题吧
  {
    cnt=0;
    ans=0;
    tot=0;
    memset(son,0,sizeof son);
    memset(fail,0,sizeof fail);
    memset(num,0,sizeof num);
    memset(b,0,sizeof b);
    return 0;
  }


  int insert(char* s)//在自动机中插入一个串s(其实就是trie树的插入方法)
  {
    int len=strlen(s),now=0;//now记录已经插入到了哪一位
    for(register int i=0; i<len; ++i)
      {
        if(!son[now][s[i]-'a'])//如果没有这个节点,就新建一个
          {
            cnt++;
            son[now][s[i]-'a']=cnt;
          }
        now=son[now][s[i]-'a'];
      }
    tot++;//将这个串的信息记录一下
    num[now]++;
    w[tot]=now;
    return 0;
  }

  int build()//这个是建立fail指针
  {
    head=0;
    tail=0;
    for(register int i=0; i<26; ++i)//这个是队列的初始化
      {
        if(son[0][i])
          {
            tail++;
            q[tail]=son[0][i];
            fail[son[0][i]]=0;
          }
      }
    while(head!=tail)//用队列来优化fail指针的计算
      {
        head++;
        int now=q[head];
        for(register int i=0; i<26; ++i)
          {
            if(son[now][i])//这个自己想一想就好了
              {
                int r=fail[now];
                while((!son[r][i])&&(r))
                  {
                    r=fail[r];
                  }
                fail[son[now][i]]=son[r][i];
                tail++;
                q[tail]=son[now][i];
              }
            else
              {
                son[now][i]=son[fail[now]][i];
              }
          }
      }
    return 0;
  }

  int find(char* s)//让这些模式串匹配主串s,就是正文中的方法
  {
    int len=strlen(s),now=0;
    for(register int i=0; i<len; ++i)
      {
        while((now)&&(!son[now][s[i]-'a']))
          {
            now=fail[now];
          }
        if(son[now][s[i]-'a'])
          {
            now=son[now][s[i]-'a'];
            int r=now;
            while(r&&(!b[r]))
              {
                b[r]=1;
                r=fail[r];
              }
          }
      }
    return 0;
  }

  int work(char* s)//统计答案
  {
    build();
    find(s);
    for(int i=1; i<=tot; ++i)
      {
        ans+=b[w[i]];
      }
    printf("%d\n",ans);
    return 0;
  }
};

代码(指针版)

struct node
{
  node *ch[26],*fail;
  int vis;
};

namespace acam
{
  node bin[maxk+10],*root,*end[maxn+10];
  int cnt;
  std::queue<node*> q;

  int build()
  {
    cnt=0;
    root=&bin[++cnt];
    return 0;
  }

  int addstr(char *s,int len,int id)
  {
    node *now=root;
    for(int i=0; i<len; ++i)
      {
        if(now->ch[s[i]-'a']==NULL)
          {
            now->ch[s[i]-'a']=&bin[++cnt];
          }
        now=now->ch[s[i]-'a'];
      }
    end[id]=now;
    return 0;
  }

  int getfail()
  {
    root->fail=root;
    for(int i=0; i<26; ++i)
      {
        if(root->ch[i]!=NULL)
          {
            root->ch[i]->fail=root;
            q.push(root->ch[i]);
          }
      }
    while(!q.empty())
      {
        node *u=q.front();
        q.pop();
        for(int i=0; i<26; ++i)
          {
            if(u->ch[i]!=NULL)
              {
                node *now=u->fail;
                while((now!=root)&&(now->ch[i]==NULL))
                  {
                    now=now->fail;
                  }
                if(now->ch[i]!=NULL)
                  {
                    u->ch[i]->fail=now->ch[i];
                  }
                else
                  {
                    u->ch[i]->fail=root;
                  }
                q.push(u->ch[i]);
              }
          }
      }
    return 0;
  }

  int getans(char *s,int len)
  {
    node *now=root;
    for(int i=0; i<len; ++i)
      {
        while((now!=root)&&(now->ch[s[i]-'a']==NULL))
          {
            now=now->fail;
          }
        if(now->ch[s[i]-'a']!=NULL)
          {
            now=now->ch[s[i]-'a'];
          }
        node *u=now;
        while(!u->vis)
          {
            u->vis=1;
            u=u->fail;
          }
      }
    return 0;
  }
}
//对于i串,如果end[i]->vis==1那么就有匹配,否则没有