pku3415 Common Substrings
后缀数组+栈的线性扫描统计两个字符串的长度不少于K的公共子串个数
separate()构造原字符串str1和str2的高度数组,根据lcp(sa[i],sa[j])=rmq(height[i+1..j]),注意到各后缀在合并后得到的字符串中保持在原字符串中的字典序
#include <iostream>
using namespace std;
#define MAXN 200010
#define clr(x) memset(x,0,sizeof(x))
int b[MAXN],array[4][MAXN],*sa,*nsa,*rank,*nrank,
h1[MAXN],h2[MAXN],h3[MAXN],K,n1,n2,n3;
char str1[MAXN/2],str2[MAXN/2],str3[MAXN];
int st[MAXN];
void make_sa(){
int i,j,k;
sa=array[0];
nsa=array[1];
rank=array[2];
nrank=array[3];
memset(b,0,sizeof(b));
for(i=0;i<n3;i++)
b[str3[i]]++;
for(i=1;i<=256;i++)
b[i]+=b[i-1];
for(i=n3-1;i>=0;i--)
sa[--b[str3[i]]]=i;
for(rank[sa[0]]=0,i=1;i<n3;i++){
rank[sa[i]]=rank[sa[i-1]];
if(str3[sa[i]]!=str3[sa[i-1]])
rank[sa[i]]++;
}
for(k=1;k<n3 && rank[sa[n3-1]]<n3-1;k*=2){
for(i=0;i<n3;i++)
b[rank[sa[i]]]=i;
for(i=n3-1;i>=0;i--)
if(sa[i]-k>=0)
nsa[b[rank[sa[i]-k]]--]=sa[i]-k;
for(i=n3-k;i<n3;i++)
nsa[b[rank[i]]--]=i;
for(nrank[nsa[0]]=0,i=1;i<n3;i++){
nrank[nsa[i]]=nrank[nsa[i-1]];
if(rank[nsa[i]]!=rank[nsa[i-1]] || rank[nsa[i]+k]!=rank[nsa[i-1]+k])
nrank[nsa[i]]++;
}
int *t=sa;sa=nsa;nsa=t;
t=rank;rank=nrank;nrank=t;
}
for(i=0,k=0;i<n3;i++){
if(rank[i]==0)
h3[rank[i]]=0;
else{
for(j=sa[rank[i]-1];str3[i+k]==str3[j+k];k++);
h3[rank[i]]=k;
if(k>0)
k--;
}
}
}
__int64 solve(int h[],int n){
int i=0,t,top=0,m,num;
__int64 cnt=0;
st[top]=0;
h[0]=h[n]=K-1;
while(i<=n){
t=h[st[top]];//定义与height[st[top]]相关的后缀为栈顶后缀(乱编的)
if(h[i]<K && top==0)
i++;
else
if(h[i]>t)
st[++top]=i++;
else
if(h[i]==t)
i++;
else{
m=i-st[top]+1;
if(h[i]>=K && h[i]>h[st[top-1]]){
//统计与当前后缀和栈顶后缀的公共前缀不相关的
//且属于栈顶后缀的公共前缀的部分(再读一遍?)
num=t-h[i];
h[st[top]]=h[i];//下次统计当前后缀和栈顶后缀的长度为h[i]的公共前缀部分
}
else//当前后缀不能与栈顶后缀构成长度至少为K公共前缀,故暂不
//纳入统计,且看它与下一个后缀能否构成公共K前缀
num=t-h[st[--top]];
cnt+=(__int64)m*(m-1)/2*num;
}
}
return cnt;
}
void separate(){
clr(h1);
clr(h2);
int mn1=INT_MAX,mn2=INT_MAX,i,cnt1=0,cnt2=0;
for(i=0;i<n3;i++){
if(h3[i]<mn1)
mn1=h3[i];
if(h3[i]<mn2)
mn2=h3[i];
if(sa[i]<n1){
h1[cnt1++]=mn1;
mn1=INT_MAX;
}
else{
h2[cnt2++]=mn2;
mn2=INT_MAX;
}
}
}
int main(){
__int64 ans;
while(scanf("%d",&K) && K){
scanf("%s%s",str1,str2);
n1=strlen(str1);
n2=strlen(str2);
str1[n1++]='#';
str1[n1]='\0';
str2[n2++]='$';
str2[n2]='\0';
strcpy(str3,str1);
strcat(str3,str2);
n3=n1+n2;
make_sa();
separate();
ans=solve(h3,n3)-solve(h2,n2)-solve(h1,n1);
printf("%I64d\n",ans);
}
return 0;
}
using namespace std;
#define MAXN 200010
#define clr(x) memset(x,0,sizeof(x))
int b[MAXN],array[4][MAXN],*sa,*nsa,*rank,*nrank,
h1[MAXN],h2[MAXN],h3[MAXN],K,n1,n2,n3;
char str1[MAXN/2],str2[MAXN/2],str3[MAXN];
int st[MAXN];
void make_sa(){
int i,j,k;
sa=array[0];
nsa=array[1];
rank=array[2];
nrank=array[3];
memset(b,0,sizeof(b));
for(i=0;i<n3;i++)
b[str3[i]]++;
for(i=1;i<=256;i++)
b[i]+=b[i-1];
for(i=n3-1;i>=0;i--)
sa[--b[str3[i]]]=i;
for(rank[sa[0]]=0,i=1;i<n3;i++){
rank[sa[i]]=rank[sa[i-1]];
if(str3[sa[i]]!=str3[sa[i-1]])
rank[sa[i]]++;
}
for(k=1;k<n3 && rank[sa[n3-1]]<n3-1;k*=2){
for(i=0;i<n3;i++)
b[rank[sa[i]]]=i;
for(i=n3-1;i>=0;i--)
if(sa[i]-k>=0)
nsa[b[rank[sa[i]-k]]--]=sa[i]-k;
for(i=n3-k;i<n3;i++)
nsa[b[rank[i]]--]=i;
for(nrank[nsa[0]]=0,i=1;i<n3;i++){
nrank[nsa[i]]=nrank[nsa[i-1]];
if(rank[nsa[i]]!=rank[nsa[i-1]] || rank[nsa[i]+k]!=rank[nsa[i-1]+k])
nrank[nsa[i]]++;
}
int *t=sa;sa=nsa;nsa=t;
t=rank;rank=nrank;nrank=t;
}
for(i=0,k=0;i<n3;i++){
if(rank[i]==0)
h3[rank[i]]=0;
else{
for(j=sa[rank[i]-1];str3[i+k]==str3[j+k];k++);
h3[rank[i]]=k;
if(k>0)
k--;
}
}
}
__int64 solve(int h[],int n){
int i=0,t,top=0,m,num;
__int64 cnt=0;
st[top]=0;
h[0]=h[n]=K-1;
while(i<=n){
t=h[st[top]];//定义与height[st[top]]相关的后缀为栈顶后缀(乱编的)
if(h[i]<K && top==0)
i++;
else
if(h[i]>t)
st[++top]=i++;
else
if(h[i]==t)
i++;
else{
m=i-st[top]+1;
if(h[i]>=K && h[i]>h[st[top-1]]){
//统计与当前后缀和栈顶后缀的公共前缀不相关的
//且属于栈顶后缀的公共前缀的部分(再读一遍?)
num=t-h[i];
h[st[top]]=h[i];//下次统计当前后缀和栈顶后缀的长度为h[i]的公共前缀部分
}
else//当前后缀不能与栈顶后缀构成长度至少为K公共前缀,故暂不
//纳入统计,且看它与下一个后缀能否构成公共K前缀
num=t-h[st[--top]];
cnt+=(__int64)m*(m-1)/2*num;
}
}
return cnt;
}
void separate(){
clr(h1);
clr(h2);
int mn1=INT_MAX,mn2=INT_MAX,i,cnt1=0,cnt2=0;
for(i=0;i<n3;i++){
if(h3[i]<mn1)
mn1=h3[i];
if(h3[i]<mn2)
mn2=h3[i];
if(sa[i]<n1){
h1[cnt1++]=mn1;
mn1=INT_MAX;
}
else{
h2[cnt2++]=mn2;
mn2=INT_MAX;
}
}
}
int main(){
__int64 ans;
while(scanf("%d",&K) && K){
scanf("%s%s",str1,str2);
n1=strlen(str1);
n2=strlen(str2);
str1[n1++]='#';
str1[n1]='\0';
str2[n2++]='$';
str2[n2]='\0';
strcpy(str3,str1);
strcat(str3,str2);
n3=n1+n2;
make_sa();
separate();
ans=solve(h3,n3)-solve(h2,n2)-solve(h1,n1);
printf("%I64d\n",ans);
}
return 0;
}