POJ 3415:后缀数组+单调栈优化

题意很简单,求两个字符串长度大于等于K的子串个数

一开始还是只会暴力。。发现n^2根本没法做。。。看了题解理解了半天才弄出来,太弱了。。。

思路:把两个字符串连接后做一次后缀数组,求出height

暴力的想法自然是枚举第一个子串的起始位置i和第二个子串的起始位置j,肯定会T的

看了题解才知道有单调栈这种有优化方法。。

将后缀分为A组(起始点为第一个字符串)、B组

设符合要求的lcp长度为a,则其对答案的贡献为a-k+1(长度为k~a的都是符合要求的)

一开始这里我也是有疑问的,比如说k=1,aa和aab两个suffix,lcp=2,这样的话贡献为2-1+1=2

但是我们可以看到有这两个suffix符合要求的是有3个的(长度1的2个,2的1个),我们似乎只统计了长度1的和长度2的各一个

确实如此。但是我们不能小范围的看问题,应该要结合整个sa[]和height[]

这里没统计不意味着之后不统计,还会有一个a后缀和一个ab后缀,这里又贡献了1.这样就补齐了~

我们知道两个suffix的lcp是height中对应段height[_rank[i]+1]~height[_rank[j]]的最小值,应用这个性质我们可以来维护单调递增的栈

 

为什么要这样呢?可以理解为栈里是可能被用到的候选序列,如果当前扫描到的height小于栈顶(候选最大值),则根据上面的性质,

可以得出大于height的值是无法做出贡献了(或者说贡献变小了),那累加器的值要更新

同时我们这里用到了一个优化,为了防止栈内元素过多,我们把同一个height值,捆绑一个个数num,这样可以提升统计效率

 

我们对A、B组各做一次,加起来就是答案

一开始我在这里没有完全理解,总感觉会重复统计。其实是没有的。为什么呢?

那统计B组的时候为例。

对于B组后缀j,我统计答案都是在sa[j]之前找。

比如说找到A组中的ii,jj,kk三个后缀是符合的,那必定有sa[ii]、sa[jj]、sa[kk]都小于sa[j]

所以在统计A组时,sa[ii]也是在sa[ii]之前找,不可能找到sa[j]

#include"cstdio"
#include"queue"
#include"cmath"
#include"stack"
#include"iostream"
#include"algorithm"
#include"cstring"
#include"queue"
#include"map"
#include"set"
#include"vector"
#define LL long long
#define mems(a,b) memset(a,b,sizeof(a))
#define ls pos<<1
#define rs pos<<1|1
#define max(a,b) (a)>(b)?(a):(b)
using namespace std;

const int MAXN = 200500;
const int INF = 0x3f3f3f3f;

char a[MAXN],b[MAXN];
int sa[MAXN],_rank[MAXN];
int wa[MAXN],wb[MAXN],wv[MAXN],Ws[MAXN];
int cmp(int *r,int a,int b,int l)
{return r[a]==r[b]&&r[a+l]==r[b+l];}
void get_sa(char *r,int *sa,int n,int m){
    int i,j,p,*x=wa,*y=wb,*t;
    for(i=0;i<m;i++) Ws[i]=0;
    for(i=0;i<n;i++) Ws[x[i]=r[i]]++;
    for(i=1;i<m;i++) Ws[i]+=Ws[i-1];
    for(i=n-1;i>=0;i--) sa[--Ws[x[i]]]=i;
    for(j=1,p=1;p<n;j*=2,m=p){
        for(p=0,i=n-j;i<n;i++) y[p++]=i;
        for(i=0;i<n;i++) if(sa[i]>=j) y[p++]=sa[i]-j;
        for(i=0;i<n;i++) wv[i]=x[y[i]];
        for(i=0;i<m;i++) Ws[i]=0;
        for(i=0;i<n;i++) Ws[wv[i]]++;
        for(i=1;i<m;i++) Ws[i]+=Ws[i-1];
        for(i=n-1;i>=0;i--) sa[--Ws[wv[i]]]=y[i];
        for(t=x,x=y,y=t,p=1,x[sa[0]]=0,i=1;i<n;i++)
            x[sa[i]]=cmp(y,sa[i-1],sa[i],j)?p-1:p++;
    }
    return;
}

int height[MAXN];
void get_height(char *r,int *sa,int n){
    int i,j,k=0;
    for(i=1;i<=n;i++) _rank[sa[i]]=i;
    for(i=0;i<n;height[_rank[i++]]=k)
    for(k?k--:0,j=sa[_rank[i]-1];r[i+k]==r[j+k];k++);
    return;
}

int k;
int main(){
    //freopen("in.txt","r",stdin);
    while(scanf("%d",&k)&&k){
        scanf("%s%s",a,b);
        int la=strlen(a);
        int lb=strlen(b);
        a[la]='$';
        for(int i=0;i<lb;i++) a[i+la+1]=b[i];
        int n=la+lb;    //最后一个元素的下标
        a[++n]='\0';
        get_sa(a,sa,n+1,300);
        get_height(a,sa,n);
        LL ans=0,cnt=0;         //cnt为累加器
        int top=0;
        pair<int,int> s[MAXN];
        for(int i=1;i<=n;i++){
            if(height[i]<k) top=0,cnt=0;
            else{
                int num=0;      //统计同一height值的个数
                if(sa[i-1]<la) num++,cnt+=height[i]-k+1;
                while(top&&height[i]<=s[top].first){
                    cnt-=s[top].second*(s[top].first-height[i]);    //如果栈顶元素的height大于等于当前height,则贡献会变小s[top].first-height[i],更新累加器
                    num+=s[top--].second;                           //更新个数
                }
                s[++top]=make_pair(height[i],num);
                if(sa[i]>la) ans+=cnt;
            }
        }
        top=0;
        cnt=0;
        for(int i=1;i<=n;i++){
            if(height[i]<k) top=0,cnt=0;
            else{
                int num=0;
                if(sa[i-1]>la) num++,cnt+=height[i]-k+1;
                while(top&&height[i]<=s[top].first){
                    cnt-=s[top].second*(s[top].first-height[i]);
                    num+=s[top--].second;
                }
                s[++top]=make_pair(height[i],num);
                if(sa[i]<la) ans+=cnt;
            }
        }
        printf("%lld\n",ans);
    }
    return 0;
}
View Code

 

posted @ 2016-03-12 23:53  Septher  阅读(843)  评论(0编辑  收藏  举报