【BZOJ3277/3473】串/字符串 后缀数组+二分+RMQ+双指针

【BZOJ3277】串

Description

字符串是oi界常考的问题。现在给定你n个字符串,询问每个字符串有多少子串(不包括空串)是所有n个字符串中至少k个字符串的子串(注意包括本身)。

Input

第一行两个整数n,k。   接下来n行每行一个字符串。

Output

输出一行n个整数,第i个整数表示第i个字符串的答案。

Sample Input

3 1 abc a ab

Sample Output

6 1 3

HINT

对于100%的数据,n,k,l<=100000

题解:需要的用的方法好像有点多,但是也比我一开始自己yy的要少,我一开始yy的是后缀数组+主席树+线段树(233)

首先用到这样一个结论,就是如果第i个后缀有x个前缀能被k个串包含,那么第i+1个后缀至少有x-1个前缀能被k个串包含(与height数组的求法类似~)

那么我们先预处理这样一个东西,ls[i]代表从最大的j使得[j,i]中包含k个串([j,i]我指的是height数组上的一段区间)。这个可以用双指针法直接搞。

然后我们仿照height数组的求法,假设后缀i-1有x个合法前缀,那么到第i个后缀的时候我们就从x开始向上枚举,用二分+ST表找出它左(右)边第一个height比它小的位置,再用ls判断一下这段区间是否包含k个串就行了,时间复杂度O(nlogn)。

#include <cstdio> 
#include <iostream> 
#include <cstring> 
#define lson x<<1 
#define rson x<<1|1 
using namespace std; 
const int maxn=200010; 
int num,n,k,len,m,t,sum; 
int r[maxn],sa[maxn],st[maxn],ra[maxn],rb[maxn],h[maxn],rank[maxn],bel[maxn],s[maxn],last[maxn]; 
char str[maxn]; 
int f[maxn][20],Log[maxn],ls[maxn],v[maxn]; 
long long ans[maxn]; 
void work() 
{ 
    int i,j,p,*x=ra,*y=rb; 
    for(i=0;i<n;i++) st[x[i]=r[i]]++; 
    for(i=1;i<m;i++) st[i]+=st[i-1]; 
    for(i=n-1;i>=0;i--)  sa[--st[x[i]]]=i; 
    for(j=p=1;p<n;j<<=1,m=p) 
    { 
        for(p=0,i=n-j;i<n;i++)   y[p++]=i; 
        for(i=0;i<n;i++) if(sa[i]>=j) y[p++]=sa[i]-j; 
        for(i=0;i<m;i++) st[i]=0; 
        for(i=0;i<n;i++) st[x[y[i]]]++; 
        for(i=1;i<m;i++) st[i]+=st[i-1]; 
        for(i=n-1;i>=0;i--)  sa[--st[x[y[i]]]]=y[i]; 
        for(swap(x,y),i=p=1,x[sa[0]]=0;i<n;i++) 
            x[sa[i]]=(y[sa[i]]==y[sa[i-1]]&&y[sa[i]+j]==y[sa[i-1]+j])?p-1:p++; 
    } 
    for(i=1;i<n;i++) rank[sa[i]]=i; 
    for(i=p=0;i<n-1;h[rank[i++]]=p) 
        for(p?p--:0,j=sa[rank[i]-1];r[i+p]==r[j+p];p++); 
} 
int query(int a,int b) 
{ 
    if(a>b)  return 1<<30; 
    int c=Log[b-a+1]; 
    return min(f[a][c],f[b-(1<<c)+1][c]); 
} 
int check(int x,int y) 
{ 
    int L=x,R=n,mid; 
    while(L<R) 
    { 
        mid=L+R>>1; 
        if(query(x+1,mid)>=y)    L=mid+1; 
        else    R=mid; 
    } 
    if(query(ls[L-1]+1,L-1)>=y)  return 1; 
    return 0; 
} 
int main() 
{ 
    scanf("%d%d",&num,&k); 
    int i,j,a,p; 
    for(i=1;i<=num;i++) 
    { 
        scanf("%s",str),a=strlen(str); 
        for(j=0;j<a;j++) bel[n]=i,r[n++]=str[j]-'a'+num; 
        last[i]=n,r[n++]=num-i; 
    } 
    m=26+num; 
    work(); 
    for(i=num;i<n;i++)   f[i][0]=h[i]; 
    for(i=2;i<=n;i++)    Log[i]=Log[i>>1]+1; 
    for(j=1;(1<<j)<n;j++) 
        for(i=num;i+(1<<j)-1<=n;i++)   f[i][j]=min(f[i][j-1],f[i+(1<<j-1)][j-1]); 
    for(i=num;i<n&&sum<k;i++) sum+=!s[bel[sa[i]]],s[bel[sa[i]]]++; 
    ls[i-1]=num; 
    for(j=num;i<n;i++) 
    { 
        sum+=!s[bel[sa[i]]],s[bel[sa[i]]]++; 
        for(;j<=i&&sum>=k;j++)    s[bel[sa[j]]]--,sum-=!s[bel[sa[j]]]; 
        ls[i]=j-1; 
    } 
    for(i=p=0;i<n;i++) 
    { 
        if(!bel[i]) continue; 
        for(p?p--:0;p<last[bel[i]]-i&&check(rank[i],p+1);p++); 
        ans[bel[i]]+=p; 
    } 
    for(i=1;i<num;i++)   printf("%lld ",ans[i]); 
    printf("%lld",ans[num]); 
    return 0; 
}

 

posted @ 2017-06-06 17:18  CQzhangyu  阅读(379)  评论(0编辑  收藏  举报