Common Substrings POJ - 3415(后缀数组+单调栈)
Common Substrings POJ - 3415(后缀数组+单调栈)
题意:给一个k,两个字符串,求两串各选一个子串中长度大于k且相等的对数;
题解:将两字符串间加一个‘#’连接,利用后缀数组求出最长公共前缀,利用rmq处理,枚举所有后缀对,其答案就是每对的共享就是(最长公共前缀-k+1),但是枚举要n* n * log(n)的时间,所以我们这里要利用单调栈优化,如后缀排序后的1串与2串的lcp是3,2串与3串的lcp是1,那么1串与3串的lcp是1,具有单调性,利用单调栈统计第i个后缀前的lcp的种类数与对应个数既可代替rmq计算出答案,因为这里给了两个字符串,要求1串找2串,2串找1串,故需要两个单调栈分别计算贡献。
代码:
#include<iostream>
#include<cstdio>
#include<cstring>
#include<stack>
using namespace std;
typedef long long ll;
const int N=200010;
int wa[N],wb[N],wv[N],wss[N],cal[N];
int rak[N],height[N],sa[N];
int n;
char s[N],s2[N];
int cmp(int *r,int a,int b,int l){
return r[a]==r[b]&&r[a+l]==r[b+l];
}
void da(int *r,int *sa,int n,int M) {
int i,j,p,*x=wa,*y=wb,*t;
for(i=0;i<M;i++) wss[i]=0;
for(i=0;i<n;i++) wss[x[i]=r[i]]++;
for(i=1;i<M;i++) wss[i]+=wss[i-1];
for(i=n-1;i>=0;i--) sa[--wss[x[i]]]=i;
for(j=1,p=1;p<n;j*=2,M=p) {
for(p=0,i=n-j;i<n;i++) y[p++]=i;
for(i=0;i<n;i++) if(sa[i]>=j) y[p++]=sa[i]-j;
for(i=0;i<n;i++) wv[i]=x[y[i]];
for(i=0;i<M;i++) wss[i]=0;
for(i=0;i<n;i++) wss[wv[i]]++;
for(i=1;i<M;i++)wss[i]+=wss[i-1];
for(i=n-1;i>=0;i--) sa[--wss[wv[i]]]=y[i];
for(t=x,x=y,y=t,p=1,x[sa[0]]=0,i=1;i<n;i++)
x[sa[i]]=cmp(y,sa[i-1],sa[i],j)?p-1:p++;
}
return;
}
void calheight(int *r,int *sa,int n) {
int i,j,k=0;
for(i=1;i<=n;i++) rak[sa[i]]=i;
for(i=0;i<n;height[rak[i++]]=k)
for(k?k--:0,j=sa[rak[i]-1];r[i+k]==r[j+k];k++);
for(int i=n;i;i--)rak[i]=rak[i-1],sa[i]++;
}
stack<pair<int,int> >ma;
stack<pair<int,int> >mb;
pair<int,int>lin;
int k,m;
int main(){
int cas=1;
while(~scanf("%d",&k)){
if(k==0)break;
scanf("%s%s",s+1,s2+1);
while(!ma.empty())ma.pop();
while(!mb.empty())mb.pop();
m=strlen(s+1);
strcat(s+1,"#");
strcat(s+1,s2+1);
n=strlen(s+1);
for(int i=1;i<=n;i++)
cal[i]=s[i];
cal[n+1]=0;
da(cal+1,sa,n+1,200);
calheight(cal+1,sa,n);
for(int i=1;i<=n;i++){
height[i]=max(0,height[i]-k+1);
}
long long cnt1=0,cnt2=0,ans=0;
for(int i=2;i<=n;i++){
int sum=0;
if(sa[i]<=m){
sum=0;
while(!ma.empty()){
lin=ma.top();
if(height[i]<=lin.first){
sum+=lin.second;
cnt1-=lin.first*lin.second;
ma.pop();
}
else break;
}
ma.push({height[i],sum});
cnt1+=height[i]*sum;
ans+=cnt1;
sum=0;
while(!mb.empty()){
lin=mb.top();
if(height[i]<lin.first){
sum+=lin.second;
cnt2-=lin.first*lin.second;
mb.pop();
}
else break;
}
mb.push({height[i],sum});
cnt2+=height[i]*sum;
if(i+1<=n){
sum=1;
while(!mb.empty()){
lin=mb.top();
if(height[i+1]<=lin.first){
sum+=lin.second;
cnt2-=lin.first*lin.second;
mb.pop();
}
else break;
}
mb.push({height[i+1],sum});
cnt2+=height[i+1]*sum;
}
}
else{
sum=0;
while(!mb.empty()){
lin=mb.top();
if(height[i]<lin.first){
sum+=lin.second;
cnt2-=lin.first*lin.second;
mb.pop();
}
else break;
}
mb.push({height[i],sum});
cnt2+=height[i]*sum;
ans+=cnt2;
sum=0;
while(!ma.empty()){
lin=ma.top();
if(height[i]<=lin.first){
sum+=lin.second;
cnt1-=lin.first*lin.second;
ma.pop();
}
else break;
}
ma.push({height[i],sum});
cnt1+=height[i]*sum;
if(i+1<=n){
sum=1;
while(!ma.empty()){
lin=ma.top();
if(height[i+1]<lin.first){
sum+=lin.second;
cnt1-=lin.first*lin.second;
ma.pop();
}
else break;
}
ma.push({height[i+1],sum});
cnt1+=height[i+1]*sum;
}
}
}
printf("%lld\n",ans);
}
}