BZOJ 2434 阿狸的打字机(ac自动机+dfs序+树状数组)

题意

给你一些串,还有一些询问
问你第x个串在第y个串中出现了多少次

思路

对这些串建ac自动机
根据fail树的性质:若x节点是trie中root到t任意一个节点的fail树的祖先,那么x一定是y的子串
而x在y中出现的次数为以x为fail树中的根节点的子树中,有多少个节点是trie树中根节点到y的
首先对询问离线
由于这题是一个节点一个节点建的ac自动机,所以我们可以根据这个建立的路径来维护一个当前节点到根的路径
路径上的点权为1,其余为0
每次到达一个询问到的y,我们就查询此时x的fail树中的子树和,这个可以预处理出fail树的dfs序用树状数组维护

但是如果排除这题的特殊性,要通过遍历trie的方式来更新答案
因为在操作fail的过程中,trie改变了

else trie[now][i]=trie[fail[now]][i];

也就是说bfs构造fail指针的时候,有时候会让trie指向之前的节点,也就是(trie中)高度比较小的节点
所以dfs只需要这样写

void dfs(int x){
    //printf("==%d\n",x);
    add(bg[x],1);
    for(int i = 0; i < (int)Q[x].size(); i++){
        int y = Q[x][i].fst;
        ans[Q[x][i].sc]=sum(ed[y])-sum(bg[y]-1);
    }
    for(int i = 0; i < 26; i++){
        int y = ac.trie[x][i];
        if(!y||h[y]<=h[x])continue;
        dfs(y);
    }
    add(bg[x],-1);
}

代码

版本一

#include<iostream>
#include<cstdio>
#include<algorithm>
#include<cmath>
#include<cstring>
#include<string>
#include<stack>
#include<queue>
#include<deque>
#include<set>
#include<vector>
#include<map>
#include<functional>
    
#define fst first
#define sc second
#define pb push_back
#define mem(a,b) memset(a,b,sizeof(a))
#define lson l,mid,root<<1
#define rson mid+1,r,root<<1|1
#define lc root<<1
#define rc root<<1|1

using namespace std;

typedef double db;
typedef long double ldb;
typedef long long ll;
typedef unsigned long long ull;
typedef pair<int,int> PI;
typedef pair<ll,ll> PLL;

const db eps = 1e-6;
const int mod = 1e9+7;
const int maxn = 3e5+100;
const int maxm = 2e6+100;
const int inf = 0x3f3f3f3f;
const db pi = acos(-1.0);

vector<int>v[maxn];
int id[maxn];
struct AC{
    //局部变量没有默认0!
    int trie[maxn][26];
    int num[maxn];//单词出现次数
    int fail[maxn],fa[maxn];
    int vis[maxn];//ask函数用到
    int tot,rt;
    int ntot;
    //多测可写个init
    void init(){tot=0;mem(vis,0);mem(trie,0);rt=0;ntot=0;}
    void add(char c){
        if(c=='P'){
            num[rt]++;
            id[++ntot]=rt;
            return;
        }
        if(c=='B'){
            rt=fa[rt];
            return;
        }
        int x = c-'a';
        if(!trie[rt][x]){
            trie[rt][x]=++tot;
            fa[tot]=rt;
        }
        //fa[trie[rt][x]]=rt;
        rt=trie[rt][x];
    }
    void build(){
        queue<int>q;
        for(int i = 0; i < 26; i++){
            if(trie[0][i]){
                //fail[trie[0][i]]=0;
                //v[0].pb(trie[0][i]);
                q.push(trie[0][i]);
            }
        }
        while(!q.empty()){
            int now = q.front();
            q.pop();
            for(int i = 0; i < 26; i++){
                ////让这个节点的失败指针指向(((他父亲节点)的失败指针所指向的那个节点)的下一个相同节点)
                if(trie[now][i]){
                    fail[trie[now][i]]=trie[fail[now]][i];
                    //v[trie[fail[now]][i]].pb(trie[now][i]);
                    q.push(trie[now][i]);
                }
                else trie[now][i]=trie[fail[now]][i];
                //否则就让当前节点的这个子节点指向当前节点fail指针的这个子节点   
            }
            v[fail[now]].pb(now);
        }
    }
    
}ac;
char a[maxn];
int bg[maxn],ed[maxn];
int S[maxn],tot;
void gao(int x){
    S[++tot]=x;
    bg[x]=tot;
    for(int i = 0; i < (int)v[x].size(); i++){
        int y =v[x][i];
        gao(y);
    }
    //S[++tot]=x;
    ed[x]=tot;
}
int tree[maxn];
int lowbit(int x){return x&-x;}
void add(int x, int c){
    for(int i = x; i <= tot; i+=lowbit(i)){tree[i]+=c;}
}
int sum(int x){
    int ans =0;
    for(int i = x; i; i-=lowbit(i))ans+=tree[i];
    return ans;
}
vector<PI>Q[maxn];
int ans[maxn];
/*
void dfs(int x){
    //printf("%d\n",x);
    add(bg[x],1);
    for(int i = 0; i < (int)Q[x].size(); i++){
        int y = Q[x][i].fst;
        ans[Q[x][i].sc]=sum(ed[y])-sum(bg[y]-1);
    }
    for(int i = 0; i < 26; i++){
        int y = ac.trie[x][i];
        if(!y||(y==ac.trie[ac.fail[x]][i]))continue;
        dfs(y);
    }
    add(bg[x],-1);
}*/
void sv(){
    int now = 0;
    int n = strlen(a+1);
    int world=0;
    for(int i = 1; i <= n; i++){
        if(a[i]=='B'){
            add(bg[now],-1);
            now=ac.fa[now];
            continue;
        }
        else if(a[i]=='P'){
            world++;
            for(int j = 0; j < (int)Q[now].size(); j++){
                int x = Q[now][j].fst;
                int y = Q[now][j]. sc;
                ans[y]=sum(ed[x])-sum(bg[x]-1);
            }
            continue;
        }
        else{
            int x=a[i]-'a';
            now=ac.trie[now][x];
            add(bg[now],1);
        }
    }
}
int main(){
    ac.init();
    tot=0;
    scanf("%s", a+1);
    int n = strlen(a+1);
    for(int i = 1; i <= n; i++){
        ac.add(a[i]);
    }
    ac.build();
    int m;
    scanf("%d", &m);
    for(int i = 1; i <= m; i++){
        int x,y;
        scanf("%d %d", &x, &y);
        x=id[x];y=id[y];
        Q[y].pb(make_pair(x,i));
    }
    gao(0);
    //dfs(0);
    sv();
    for(int i = 1; i <= m; i++){
        printf("%d\n",ans[i]);
    }
    return 0;
}
/*
aPaPBbP
3
1 2
1 3
2 3

a
aa
aa
 */

版本二

#include<iostream>
#include<cstdio>
#include<algorithm>
#include<cmath>
#include<cstring>
#include<string>
#include<stack>
#include<queue>
#include<deque>
#include<set>
#include<vector>
#include<map>
#include<functional>
     
#define fst first
#define sc second
#define pb push_back
#define mem(a,b) memset(a,b,sizeof(a))
#define lson l,mid,root<<1
#define rson mid+1,r,root<<1|1
#define lc root<<1
#define rc root<<1|1
 
using namespace std;
 
typedef double db;
typedef long double ldb;
typedef long long ll;
typedef unsigned long long ull;
typedef pair<int,int> PI;
typedef pair<ll,ll> PLL;
 
const db eps = 1e-6;
const int mod = 1e9+7;
const int maxn = 3e5+100;
const int maxm = 2e6+100;
const int inf = 0x3f3f3f3f;
const db pi = acos(-1.0);
 
vector<int>v[maxn];
int id[maxn];
int h[maxn];
struct AC{
    //局部变量没有默认0!
    int trie[maxn][26];
    int num[maxn];//单词出现次数
    int fail[maxn],fa[maxn];
    int vis[maxn];//ask函数用到
    int tot,rt;
    int ntot;
    int H;
    //多测可写个init
    void init(){tot=0;mem(vis,0);mem(trie,0);rt=0;ntot=0;H=0;}
    void add(char c){
        if(c=='P'){
            num[rt]++;
            id[++ntot]=rt;
            return;
        }
        if(c=='B'){
            rt=fa[rt];H--;
            return;
        }
        int x = c-'a';
        if(!trie[rt][x]){
            trie[rt][x]=++tot;
            fa[tot]=rt;
        }
        //fa[trie[rt][x]]=rt;
        rt=trie[rt][x];
        h[rt]=++H;
    }
    void build(){
        queue<int>q;
        for(int i = 0; i < 26; i++){
            if(trie[0][i]){
                fail[trie[0][i]]=0;
                //v[0].pb(trie[0][i]);
                q.push(trie[0][i]);
            }
        }
        while(!q.empty()){
            int now = q.front();
            q.pop();
            for(int i = 0; i < 26; i++){
                ////让这个节点的失败指针指向(((他父亲节点)的失败指针所指向的那个节点)的下一个相同节点)
                if(trie[now][i]){
                    fail[trie[now][i]]=trie[fail[now]][i];
                    //v[trie[fail[now]][i]].pb(trie[now][i]);
                    q.push(trie[now][i]);
                }
                else trie[now][i]=trie[fail[now]][i];
                //否则就让当前节点的这个子节点指向当前节点fail指针的这个子节点   
            }
            v[fail[now]].pb(now);
        }
    }
     
}ac;
char a[maxn];
int bg[maxn],ed[maxn];
int S[maxn],tot;
void gao(int x){
    S[++tot]=x;
    bg[x]=tot;
    for(int i = 0; i < (int)v[x].size(); i++){
        int y =v[x][i];
        gao(y);
    }
    //S[++tot]=x;
    ed[x]=tot;
}
int tree[maxn];
int lowbit(int x){return x&-x;}
void add(int x, int c){
    for(int i = x; i <= tot; i+=lowbit(i)){tree[i]+=c;}
}
int sum(int x){
    int ans =0;
    for(int i = x; i; i-=lowbit(i))ans+=tree[i];
    return ans;
}
vector<PI>Q[maxn];
int ans[maxn];
 
void dfs(int x){
    //printf("==%d\n",x);
    add(bg[x],1);
    for(int i = 0; i < (int)Q[x].size(); i++){
        int y = Q[x][i].fst;
        ans[Q[x][i].sc]=sum(ed[y])-sum(bg[y]-1);
    }
    for(int i = 0; i < 26; i++){
        int y = ac.trie[x][i];
        if(!y||h[y]<=h[x])continue;
        dfs(y);
    }
    add(bg[x],-1);
}
void sv(){
    int now = 0;
    int n = strlen(a+1);
    int world=0;
    for(int i = 1; i <= n; i++){
        if(a[i]=='B'){
            add(bg[now],-1);
            now=ac.fa[now];
            continue;
        }
        else if(a[i]=='P'){
            world++;
            for(int j = 0; j < (int)Q[now].size(); j++){
                int x = Q[now][j].fst;
                int y = Q[now][j]. sc;
                ans[y]=sum(ed[x])-sum(bg[x]-1);
            }
            continue;
        }
        else{
            int x=a[i]-'a';
            now=ac.trie[now][x];
            add(bg[now],1);
        }
    }
}
int main(){
    ac.init();
    tot=0;
    scanf("%s", a+1);
    int n = strlen(a+1);
    for(int i = 1; i <= n; i++){
        ac.add(a[i]);
    }
    ac.build();
    int m;
    scanf("%d", &m);
    for(int i = 1; i <= m; i++){
        int x,y;
        scanf("%d %d", &x, &y);
        x=id[x];y=id[y];
        Q[y].pb(make_pair(x,i));
    }
    gao(0);
    dfs(0);
    //sv();
    for(int i = 1; i <= m; i++){
        printf("%d\n",ans[i]);
    }
    return 0;
}
/*
asPdPasPddPhBdPhPnaPasP
8
1 5 
2 6 
3 8 
4 3
7 7 
2 5 
6 8
1 8
 
a
aa
aa
 */
posted @ 2019-09-27 13:41  wrjlinkkkkkk  阅读(172)  评论(0编辑  收藏  举报