【bzoj3473】字符串 【后缀自动机+树状数组】

题意:给定n个字符串,询问每个字符串有多少子串(不包括空串)是所有n个字符串中至少k个字符串的子串?(本质相同重复计算)
题解:首先我们把这n个字符串的广义后缀自动机建立出来,然后处理出每个状态出现在n个串的多少个之中。接着把每个串在后缀自动机跑一遍,统计即可。
如何处理出每个状态出现在n个串的多少个之中?
如果一个状态x出现在某个串y中,那么fail[x]一定也出现在y中,因为fail[x]是x的一个后缀。设val[i]代表状态i来自于哪个串。所以如果我们把fail链倒过来建一棵树,状态x出现在n个串之中的个数就是x的子树中的val的不同个数。我们就可以把这棵树dfs一次,处理出每个节点的dfs序区间,就把这个问题转化为了查询一个区间有多少个不同的数字,跟HH的项链那题一模一样。用树状数组处理一下就好了。
如何统计?
设ans[x]表示x出现在n个串之中的个数。只需要对每个串在后缀自动机上走,如果ans[now]小于k的话now就不停地跳fail。这就相当于把当前匹配到的不停地截短。然后答案累加上len[now]即可。至于为什么,请读者自行思考。而且匹配的过程中,不会失配,这也很显然。
时间复杂度: n log n
代码实现:

#include<cstdio>
#include<algorithm>
#include<string>
#include<cstring>
#include<vector>
using namespace std;
const int N=100005;
int n,k,l;
string s[N];
char str[N];
bool cmp(int a,int b);
struct SAM{
    int last,tot,len[N*2],fail[N*2],val[N*2],ch[N*2][26],c[N*2],a[N*2];
    int idx,in[N*2],out[N*2],pos[N*2],nxt[N*2],ck[N*2],ans[N*2];
    vector<int> e[N*2];
    SAM(){
        last=tot=1;
    }
    void insert(int x,int id){
        int p=last,np=++tot;
        len[np]=len[p]+1;
        last=np;
        val[np]=id;
        for(;p&&!ch[p][x];p=fail[p]){
            ch[p][x]=np;
        }
        if(!p){
            fail[np]=1;
        }else{
            int q=ch[p][x];
            if(len[q]==len[p]+1){
                fail[np]=q;
            }else{
                int nq=++tot;
                len[nq]=len[p]+1;
                memcpy(ch[nq],ch[q],sizeof(ch[q]));
                fail[nq]=fail[q];
                fail[q]=fail[np]=nq;
                for(;p&&ch[p][x]==q;p=fail[p]){
                    ch[p][x]=nq;
                }
            }
        }
    }
    void dfs(int u){
        in[u]=++idx;
        pos[idx]=u;
        for(int i=0;i<e[u].size();i++){
            dfs(e[u][i]);
        }
        out[u]=idx;
    }
    int lowbit(int x){
        return x&(-x);
    }
    void add(int i){
        while(i<=tot){
            c[i]++;
            i+=lowbit(i);
        }
    }
    int sum(int i){
        int res=0;
        while(i){
            res+=c[i];
            i-=lowbit(i);
        }
        return res;
    }
    void build(){
        for(int i=2;i<=tot;i++){
            e[fail[i]].push_back(i);
        }
        dfs(1);
        for(int i=1;i<=tot;i++){
            a[i]=i;
        }
        sort(a+1,a+tot+1,cmp);
        for(int i=tot;i>=1;i--){
            if(val[pos[i]]){
                nxt[i]=ck[val[pos[i]]];
                ck[val[pos[i]]]=i;
            }
        }
        for(int i=1;i<=tot;i++){
            if(ck[i]){
                add(ck[i]);
            }
        }
        for(int i=1,j=1;i<=tot;i++){
            while(j<in[a[i]]){
                if(nxt[j]){
                    add(nxt[j]);
                }
                j++;
            }
            ans[a[i]]=sum(out[a[i]])-sum(in[a[i]]-1);
        }
    }
    long long query(const char *s,int l){
        long long res=0;
        int now=1;
        for(int i=0;i<l;i++){
            now=ch[now][s[i]-'a'];
            while(now&&ans[now]<k){
                now=fail[now];
            }
            if(!now){
                now=1;
                continue;
            }
            res+=len[now];
        }
        return res;
    }
}sam;
bool cmp(int a,int b){
    return sam.in[a]==sam.in[b]?sam.out[a]<sam.out[b]:sam.in[a]<sam.in[b];
}
int main(){
    scanf("%d%d",&n,&k);
    for(int i=1;i<=n;i++){
        scanf("%s",str);
        s[i]=str;
        l=strlen(str);
        sam.last=1;
        for(int j=0;j<l;j++){
            sam.insert(str[j]-'a',i);
        }
    }
    sam.build();
    for(int i=1;i<=n;i++){
        printf("%lld ",sam.query(s[i].c_str(),s[i].size()));
    }
    puts("");
    return 0;
}
posted @ 2018-04-08 21:20  一剑霜寒十四洲  阅读(128)  评论(0编辑  收藏  举报