poj 3415 Common Substrings

http://poj.org/problem?id=3415

 

题意:求两个字符串长度不小于k的公共子串数量

 

两个字符串用特殊字符连起来

后缀数组求出height数组

从大到小枚举,并查集合并

记录每一组 特殊字符前有多少个,特殊字符后有多少个,合并的贡献是 两者的乘积*(当前height-m+1)

 

#include<cstdio>
#include<vector>
#include<cstring>
#include<iostream>
#include<algorithm>

using namespace std;

#define N 100001

int n1,n,m;
char s[N<<1];

int p=0,q=1;
int v[N<<1];
int sa[2][N<<1],rk[2][N<<1],height[N<<1];

int fa[N<<1],siza[N<<1],sizb[N<<1];

vector<int>V[N<<1];

long long ans;

void mul(int k,int *sa,int *rk,int *SA,int *RK)
{
    for(int i=1;i<=n;++i) v[rk[sa[i]]]=i;
    for(int i=n;i;--i) if(sa[i]>k) SA[v[rk[sa[i]-k]]--]=sa[i]-k;
    for(int i=n-k+1;i<=n;++i) SA[v[rk[i]]--]=i;
    for(int i=1;i<=n;++i) RK[SA[i]]=RK[SA[i-1]]+(rk[SA[i]]!=rk[SA[i-1]]||rk[SA[i]+k]!=rk[SA[i-1]+k]);
}

void presa()
{
    memset(v,0,sizeof(v));
    for(int i=1;i<=n;++i) v[s[i]]++;
    for(int i=1;i<=130;++i) v[i]+=v[i-1];
    for(int i=1;i<=n;++i) sa[p][v[s[i]]--]=i;
    for(int i=1;i<=n;++i) rk[p][sa[p][i]]=rk[p][sa[p][i-1]]+(s[sa[p][i-1]]!=s[sa[p][i]]);
    for(int k=1;k<n;k<<=1,swap(p,q)) mul(k,sa[p],rk[p],sa[q],rk[q]);
}

void get_height()
{
    int j;
    for(int k=0,i=1;i<=n;++i)
    {
        j=sa[p][rk[p][i]-1];
        while(s[j+k]==s[i+k]) k++;
        height[rk[p][i]]=k;
        if(k) k--;
    }
}

int find(int i) { return fa[i]==i ? i : fa[i]=find(fa[i]); }

void unionn(int x,int y,int i)
{
    x=find(x);
    y=find(y);
    ans+=1LL*siza[x]*sizb[y]*(i-m+1);
    ans+=1LL*sizb[x]*siza[y]*(i-m+1);
    siza[y]+=siza[x];
    sizb[y]+=sizb[x];
    fa[x]=y;
}

void solve()
{
    for(int i=1;i<=n;++i) fa[i]=i;
    for(int i=1;i<=n1;++i) siza[i]=1,sizb[i]=0;
    for(int i=n1+2;i<=n;++i) sizb[i]=1,siza[i]=0;
    int mx=0;
    for(int i=2;i<=n;++i) V[height[i]].push_back(i),mx=max(mx,height[i]);
    int s,w;
    ans=0;
    for(int i=n;i>=m;--i)
    {
        s=V[i].size();
        for(int j=0;j<s;++j)
        {
            w=V[i][j];
            if(find(sa[p][w-1])!=find(sa[p][w])) unionn(sa[p][w-1],sa[p][w],i);
            if(w<n && height[w+1]>=i && find(sa[p][w+1])!=find(sa[p][w])) unionn(sa[p][w+1],sa[p][w],i);
        }
    }
    cout<<ans<<'\n';
    for(int i=0;i<=mx;++i) V[i].clear();
}

int main()
{
    while(1)
    {
        scanf("%d",&m);
        if(!m) return 0;
        scanf("%s",s+1);
        n1=n=strlen(s+1);
        s[n+1]=char('a'+26);
        scanf("%s",s+n+2);
        n+=strlen(s+n+2)+1;
        presa();
        get_height();
        solve();
    }
}    

 

posted @ 2018-03-05 11:22  TRTTG  阅读(190)  评论(0编辑  收藏  举报