hdu 6096

hdu 6096
前缀l,后缀r,中间部分mid
一个字符串为k=l+mid+r,其中l和r没有相交的部分
题目给了一些字符串和一些前后缀,问每个前后缀能匹配多少字符串
我们构造字符串s=r+'^'+l,t=k+'^'+k
用s去建立ac自动机,对每个字符串k进行查询。
怎么保证不会重叠呢?如果不重叠的话,后缀+前缀的长度最长为字符串本身的长度,又因为还有一个
中间字符'^',所以if(deep[j]<=len) ans[j]++;,这里的len为字符串长度+1

#include <bits/stdc++.h>
#define inf 2333333333333333
#define N 2000010
#define p(a) putchar(a)
#define For(i,a,b) for(int i=a;i<=b;++i)

using namespace std;
int T,n,q,cnt;
int pos[N];
char str[N],*s[N],len[N],s1[N],s2[N];

void in(int &x){
    int y=1;char c=getchar();x=0;
    while(c<'0'||c>'9'){if(c=='-')y=-1;c=getchar();}
    while(c<='9'&&c>='0'){ x=(x<<1)+(x<<3)+c-'0';c=getchar();}
    x*=y;
}
void o(int x){
    if(x<0){p('-');x=-x;}
    if(x>9)o(x/10);
    p(x%10+'0');
}

struct ac{
    int tot;
    int fail[N],tr[N][27],deep[N],ans[N];

    void init(){
        memset(fail,0,sizeof(fail));
        memset(tr,0,sizeof(tr));
        memset(deep,0,sizeof(deep));
        memset(ans,0,sizeof(ans));
        tot=0;
    }

    int insert(char *s){
        int u=0;
        for(int i=0;s[i];i++){
            int temp=s[i]-'a';
            if(!tr[u][s[i]-'a']){
                tr[u][s[i]-'a']=++tot;
                deep[tot]=i+1;
            }
            u=tr[u][s[i]-'a'];
        }
        return u;
    }
    queue<int>q;
    void build(){
        For(i,0,26) 
            if(tr[0][i])
                q.push(tr[0][i]);
        while(!q.empty()){
            int u=q.front();q.pop();
            For(i,0,26)
                if(tr[u][i]){
                    fail[tr[u][i]]=tr[fail[u]][i];
                    q.push(tr[u][i]);
                }
                else
                    tr[u][i]=tr[fail[u]][i];
        }
    }
    void query(char *t,int len){
        int u=0;
        for(int i=0;t[i];i++){
            u=tr[u][t[i]-'a'];
            for(int j=u;j;j=fail[j]){
                if(deep[j]<=len) ans[j]++;
            }
        }
    }
}ac;

signed main(){
    in(T);
    while(T--){
        ac.init();
        in(n);in(q);
        cnt=0;
        For(i,1,n){
            s[i]=str+cnt;
            scanf("%s",s[i]);
            len[i]=strlen(s[i])+1;
            cnt+=len[i];
            strcpy(str+cnt,s[i]);
            str[cnt-1]='z'+1;
            cnt+=len[i];
        }
        For(i,1,q){
            s1[0]='z'+1;
            scanf("%s%s",s1+1,s2);
            strcat(s2,s1);
            pos[i]=ac.insert(s2);
        }
        ac.build();
        For(i,1,n) ac.query(s[i],len[i]);
        For(i,1,q) o(ac.ans[pos[i]]),p('\n');
    }
    return 0;
}

 

posted @ 2020-07-04 09:49  WeiAR  阅读(118)  评论(0编辑  收藏  举报