【POJ.3415 Common Substrings】后缀数组 长度不小于K的公共子串个数

Common Substrings

题意

给出两个字符串,求他俩长度>=k的公共子串的数量。

思路

\(n^2\) 的思路比较容易想到。

我们把两个字符串用一个没有出现过的字符隔开拼接起来,做后缀数组。

那么公共子串的数量,就是A串的后缀和B串的后缀之间的所有最长公共前缀和。

统计时,遍历\(height\)数组。

对于第i个后缀,遍历\(j(j<i)\),计算\((i,j)\)的最长公共前缀

如果后缀\(i\)是A串中的,那么\(j\)就只统计B串,如果后缀\(i\)是B串,\(j\)只统计A串。

优化:
我们知道\(lcp(i,j)=min(height[i+1]...height[j])\)

根据这个可以知道对于后缀\(i\)统计的答案:

\(lcp(i,1) ,lcp(i,2) , lcp(i,3), ... , lcp(i,i-1)\)

是非递减的。

知道这一点,先统计后缀\(i\)为B串中的后缀的答案。

维护一个\(sum\),表示后缀\(i\)与其前面A串的最长公共前缀和。(后缀\(i\)

不一定是B串中的后缀)

假如现在已经前6个后缀更新完了,贡献如下:

1 2 3 4 5

\(height[7]==2\),那么对于后缀7来说,贡献就应该为:

1 2 2 2 2 2

此时我们维护一个单调递增栈和一个数组\(num\),单调栈中存放

上述贡献值\(num\)数组存放栈内的贡献值出现的次数

每次更新后缀的时候,遍历栈中贡献值>=\(height[i]\)的,

\(sum\) 减去 贡献值变小的部分* 贡献值数量,将\(height[i]\)入栈

如果此时后缀\(i\)为B串中的后缀,\(ans+=sum\)

然后统计为后缀\(i\)为A串的。

PS:

为什么要加一个字符隔开?

两个串分别为:aaaaa , aaaa;

如果不隔开,A串中的后缀5,和B串中的后缀1的公共前缀长度就是4,

但应该是1,所以要隔开。

代码

/*Gts2m ranks first in the world*/
#define pb push_back
#define stop system("pause")
#include<stdio.h>
#include<string.h>
#include<iostream>
#include<algorithm>
//#include<bits/stdc++.h>
using namespace std;
const int N=2e5+10;
typedef long long ll;
typedef unsigned long long ull;

char s[N],t[N];
int sa[N],rk[N],ht[N],oldrk[N],pos[N],cnt[N];
int n,m,lens,lent;
bool cmp(int a,int b,int k)
{
    return oldrk[a]==oldrk[b]&&oldrk[a+k]==oldrk[b+k];
}
void getsa()
{
    m=122;
    memset(cnt,0,sizeof(cnt));
    for(int i=1; i<=n; i++) ++cnt[rk[i]=s[i]];
    for(int i=1; i<=m; i++) cnt[i]+=cnt[i-1];
    for(int i=n; i; i--) sa[cnt[rk[i]]--]=i;
    for(int k=1; k<=n; k<<=1)
    {
        int num=0;
        for(int i=n-k+1; i<=n; i++) pos[++num]=i;
        for(int i=1; i<=n; i++) if(sa[i]>k) pos[++num]=sa[i]-k;
        memset(cnt,0,sizeof(cnt));
        for(int i=1; i<=n; i++) ++cnt[rk[i]];
        for(int i=1; i<=m; i++) cnt[i]+=cnt[i-1];
        for(int i=n; i; i--) sa[cnt[rk[pos[i]]]--]=pos[i];
        num=0;
        memcpy(oldrk,rk,sizeof(rk));
        for(int i=1; i<=n; i++) rk[sa[i]]=cmp(sa[i],sa[i-1],k)?num:++num;
        if(num==n) break;
        m=num;
    }
    for(int i=1; i<=n; i++) rk[sa[i]]=i;
    int k=0;
    for(int i=1; i<=n; i++)
    {
        if(k) --k;
        while(s[i+k]==s[sa[rk[i]-1]+k]) ++k;
        ht[rk[i]]=k;
    }
}
int sta[N],num[N];
int main()
{
    int k;
    while(~scanf("%d",&k)&&k)
    {
        scanf("%s%s",s+1,t+1);
        lens=strlen(s+1),lent=strlen(t+1);
        n=lens;
        s[++n]='$';
        for(int i=1; i<=lent; i++) s[++n]=t[i];
        s[++n]='%';
        getsa();
        ll ans=0,sum=0;
        int top=0;
        for(int i=1; i<=n; i++)
        {
            if(ht[i]<k)
                sum=0,top=0;
            else
            {
                int cnt=0;//cnt 表示应该等于ht[i]的贡献值的个数
                if(sa[i-1]<=lens+1)//后缀sa[i-1]是A串时
                {
                    cnt=1;
                    sum+=ht[i]-k+1;
                }
                while(top&&ht[i]<=ht[sta[top]])
                {
                    cnt+=num[top];
                    sum-=1LL*num[top]*(ht[sta[top]]-ht[i]);//更新sum值
                    top--;
                }
                sta[++top]=i,num[top]=cnt;
                if(sa[i]>lens+1)//当后缀sa[i]为B串时,更新答案
                    ans+=sum;
            }
        }
        //统计后缀i为A串时的值
        for(int i=1; i<=n; i++)
        {
            if(ht[i]<k)
                sum=0,top=0;
            else
            {
                int cnt=0;
                if(sa[i-1]>lens+1)
                {
                    cnt=1;
                    sum+=ht[i]-k+1;
                }
                while(top&&ht[i]<=ht[sta[top]])
                {
                    cnt+=num[top];
                    sum-=1LL*num[top]*(ht[sta[top]]-ht[i]);
                    top--;
                }
                sta[++top]=i,num[top]=cnt;
                if(sa[i]<=lens+1)
                    ans+=sum;
            }
        }
        printf("%lld\n",ans);
    }
    return 0;
}

posted @   Valk3  阅读(132)  评论(0编辑  收藏  举报
努力加载评论中...
点击右上角即可分享
微信分享提示