AC自动机-Fail树

\(AC\) 自动机中的 \(Fail\)

引入:

思考一下 \(AC\) 自动机的匹配过程:

从第一个字符开始,每到达一个节点 \(x\) ,就从 \(x\) 开始不断跳 \(fail\) 到根

期间跳到的节点代表的串都在文本串中出现

进阶:

既然可以从文本串的每位开始向上跳 \(fail\) 找模式串结尾节点。

那么我们也可以从结尾节点逆着 \(fail\) 找文本串节点

从结尾节点开始逆着跳 \(fail\) ,期间跳到的文本串节点个数就是这个模式串在文本串中出现的次数。

因此,建立好 \(fail\) 指针后,只留下反向的 \(fail\) 指针作为边,就得到了 \(fail\) 树。

大概长这样:

这棵树是在一个 \(trie\) 的基础上产生的,所以这棵树上的每个点都是一个字符串的前缀,而且每个字符串的每个前缀在这棵树上都对应着一个点。

其次,由于 \(fail\) 指针,每个点父节点的字符串都是这个点字符串的后缀,并且树上没有更长的它的后缀。

沿着 \(x\)\(fail\) 祖先往上走会找到 \(x\) 节点的所有后缀节点

对于字符串 \(s\) ,在自动机里匹配到的所有节点的所有 \(fail\) 祖先就表示 \(s\) 的所有子串。

利用:

只要将 \(fail\) 树上每个属于文本串的结点权值置为 \(1\) ,那么节点 \(x\) 的子树总权值就是 \(x\) 代表的串在文本串中出现的次数。

求子树权值之和,一种实现方法就是 \(DFS\) 序+树状数组

权值算法就是:

void solve(char *s){
    int x=0;len=strlen(s);
    for(int i=0;i<len;i++){
        int now=s[i]-'a';
        while(!ch[x][now]&&x) x=fail[x];
        x=ch[x][now];
        add(dfn[x],1);//插入树状数组
    }
}
void dfs(int x){
    dfn[x]=++cnt; sizes[x]=1;
    for(int i=head[x];i;i=nxt[i]){
        int y=ver[i];dfs(y);
        sizes[x]+=sizes[y]; 
    }
}
void calc(int x){
    query(dfn[val[x]]+size[val[x]]-1)-query(dfn[val[x]]-1);
    //val为该字符串在节点中的位置 ,正常的进行树状数组查询即可。
}
int main(){
....    
    for(int i=1;i<=n;i++) printf("%d\n",calc(val[i]));//查询每一个字符串末尾对应的AC自动机上的节点标号即可
}

例题:

[NOI2011] 阿狸的打字机

一开始用所有串构造好 \(fail\) 树并求出 \(DFS\) 序。

对于一个 \((x,y)\) 这样的询问,只要在 \(fail\) 树上将所有属于串 \(y\) 的结点权值赋为 \(1\) ,查询串 \(x\) 末尾结点子树权值和即可。

我们可以把所有 \(y\) 串相同的询问放在一起处理,用链式结构(就像建图一样连起来)。

然后重新走一次 \(Trie\) 树的构造过程:

  1. 每走到一个结点就将这个节点权值 \(+1\)
  2. 遇到 \(B\) 就令当前节点权值-1并返回上一个节点
  3. 每次遇到 \(P\) 时就可以查看是否有串 \(y\) 为当前字符串的询问,有就一起处理

这样做就保证了遇到 \(P\) 时只有当前字符串全部赋了值,并且遍历完只用 \(O(n)\)

#include<bits/stdc++.h>
using namespace std;
#define lowbit(x) x&-x
const int N=2e6+5,M=1e5+5;
int c[N][26],val[N],fail[N],cnt;
int nxt[N],ver[N],head[N],tot;
int fa[N],dfn[N],dep[N],sizes[N],son[N],top[N],rk[N],dfstime;

int n,Q,a[N],treesum[N]; char str[N]; 

void add(int x,int y){
    ver[++tot]=y; nxt[tot]=head[x]; head[x]=tot;
}
void ins(char *s,int id){
    int len=strlen(s);int now=0;
    for(int i=0;i<len;i++){
        int x=s[i]-'a';
        if(!c[now][x]) c[now][x]=++cnt;
        now=c[now][x];//指向地址
    }
    val[id]=now;//单词末尾
}
void build(){
    queue<int> q;
    for(int i=0;i<26;i++)//第一层是根节点0
        if(c[0][i])//遍历第二层,搜索存在的子树
            fail[c[0][i]]=0,q.push(c[0][i]);//fail指向根节点
    while(!q.empty()){
        int x=q.front();q.pop();
        for(int i=0;i<26;i++){//搜索这一棵树
            if(c[x][i]){
                fail[c[x][i]]=c[fail[x]][i];//让这个节点的失败指针指向(((他父亲节点)的失败指针所指向的那个节点)的下一个节点)
                q.push(c[x][i]);//存在子树,就压入队列
            }
            else
                c[x][i]=c[fail[x]][i];//否则就让这个子节点指向当前节点fail指针的子节点 
        } 
    }
}
void dfs1(int x,int father){
    sizes[x]=1; 
    fa[x]=father;
    dep[x]=dep[father]+1;
    for(int i=head[x];i;i=nxt[i]){
        int y=ver[i]; if(y==father) continue;
        // fa[y]=x;
        dfs1(y,x);
        sizes[x]+=sizes[y];
        if(!son[x]||sizes[y]>sizes[son[x]]) son[x]=y;
    }
}
void dfs2(int x,int topfather){
    dfn[x]=++dfstime;//dfs序
    top[x]=topfather;//这个点所在重链的顶端,对于求lca和链有极大帮助
    if(!son[x]) return;
    dfs2(son[x],topfather);//我们首先进入重儿子来保证一条重链上各个节点dfs序连续
    for(int i=head[x];i;i=nxt[i]){
        int y=ver[i];
        if(y!=son[x]&&y!=fa[x]) dfs2(y,y);//位于轻链底端,top为本身
    }
}
int lca(int x,int y){
    while(top[x]!=top[y]){
        if(dep[top[x]]<dep[top[y]]) swap(x,y);
        x=fa[top[x]];
    }
    if(dep[x]>dep[y]) swap(x,y);
    return x;
}

void update(int x,int val){
    for(;x<=dfstime;x+=lowbit(x)) treesum[x]+=val;
}
int query(int x){
    int res=0;
    for(;x;x-=lowbit(x)) res+=treesum[x]; return res;
}

bool cmp(int x,int y){return dfn[x]<dfn[y];}

inline void solve1(){
    int x=0,tp=0;
    for(int i=0;str[i];i++){
        x=c[x][str[i]-'a']; a[++tp]=x;//在 trie 树中统计走过的节点
    }
    sort(a+1,a+1+tp,cmp);
    bool flag=false;
    for(int i=1;i<=tp;i++){
        update(dfn[a[i]],1);//相邻两个节点在树上的位置 +1 ,表示多一个串匹配
        if(flag) update(dfn[lca(a[i],a[i-1])],-1);
        else flag=true;
    }
}

int main(){
    cin>>n;
    for(int i=1;i<=n;i++) scanf("%s",str),ins(str,i);
    build();
    for(int i=1;i<=cnt;i++) add(fail[i],i);
    dfs1(0,cnt+1); dfs2(0,0);
    cin>>Q;
    while(Q--){ int opt;
        scanf("%d",&opt);
        if(opt==1) scanf("%s",str),solve1();
        else if(opt==2){ 
            int x; scanf("%d",&x);
            printf("%d\n",query(dfn[val[x]]+sizes[val[x]]-1)-query(dfn[val[x]]-1));
        }       
    }
    system("pause");
    return 0;
}

posted @ 2021-09-23 21:48  Evitagen  阅读(659)  评论(0编辑  收藏  举报