poj 3415 Common Substrings 后缀数组+单调栈
题目大意
求两个串长度>=k的公共子串个数
分析
后缀数组+单调栈
考虑n^2枚举做法的优化
枚举j
再枚举\(<j的i\)
lcp(i,j)就是i-j间最小的height
贡献为height-K+1
j右移一位
左边所有\(>height[newj]\)的点的贡献就要减少了
可以发先这是height值从左往右是单调递增的
用单调栈维护
退栈时时同一height合并
对于B求一次A
对于A求一次B
求的时候就按上面的方法
不会算重算漏
注意
模板里要注意的地方都打注释了
solution
#include <cstdio>
#include <cstdlib>
#include <cstring>
#include <cctype>
#include <cmath>
#include <algorithm>
using namespace std;
const int M=400007;
typedef long long LL;
int tcas=0;
int mid;
int K,n;
char s[M];
int sa[M],t[M];
int rk[M],f[M];
int sum[M],h[M];
void getsa(){
memset(sum,0,sizeof(sum));//
int i,j,p,nw=500;
for(i=1;i<=n;i++) sum[s[i]]++;
for(i=1;i<=nw;i++) sum[i]+=sum[i-1];
for(i=n;i>0;i--) sa[sum[s[i]]--]=i;
for(p=0,i=1;i<=n;i++) rk[sa[i]]=(s[sa[i]]!=s[sa[i-1]])?(++p):(p);
for(nw=p,j=1;nw!=n;j<<=1,nw=p){
memset(sum,0,sizeof(sum));
memcpy(f,rk,sizeof(rk));
for(p=0,i=n-j+1;i<=n;i++) t[++p]=i;
for(i=1;i<=n;i++) if(sa[i]>j) t[++p]=sa[i]-j;
for(i=1;i<=n;i++) sum[f[i]]++;
for(i=1;i<=nw;i++) sum[i]+=sum[i-1];
for(i=n;i>0;i--) sa[sum[f[t[i]]]--]=t[i];//=t[i]
for(p=0,i=1;i<=n;i++) rk[sa[i]]=(f[sa[i]]!=f[sa[i-1]]||f[sa[i]+j]!=f[sa[i-1]+j])?(++p):(p);
}
}
void geth(){
int i,j,p=0;
for(i=1;i<=n;i++){
j=sa[rk[i]-1];
for(;i<=n&&j<=n&&s[i+p]==s[j+p];p++);
h[rk[i]]=p;//rk[i]
if(p) p--;
}
h[1]=0;//
}
struct node{
LL h,num;
}st[M];
int top,num;
LL cnt;
void solve(){
LL res=0;
top=num=0; cnt=0;
for(int i=1;i<=n;i++){
if(h[i]<K){
top=0; cnt=0;
continue;
}
num=0;
if(sa[i-1]<mid) cnt+=h[i]-K+1,num=1;
for(;top&&st[top].h>=h[i];top--){//小心越界
cnt-=st[top].num*(st[top].h-h[i]);
num+=st[top].num;
}
if(num){
st[++top].h=h[i];
st[top].num=num;
}
if(sa[i]>mid) res+=cnt;
}
for(int i=1;i<=n;i++){
if(h[i]<K){
top=0; cnt=0;
continue;
}
num=0;
if(sa[i-1]>mid) cnt+=h[i]-K+1,num=1;
for(;top&&st[top].h>=h[i];top--){
cnt-=st[top].num*(st[top].h-h[i]);
num+=st[top].num;
}
if(num){
st[++top].h=h[i];
st[top].num=num;
}
if(sa[i]<mid) res+=cnt;
}
printf("%lld\n",res);
}
int main(){
while(1){
scanf("%d",&K);
if(K==0) break;
scanf("%s",s+1);
n=strlen(s+1);
s[++n]='+';
mid=n;
scanf("%s",s+n+1);
n+=strlen(s+n+1);
getsa();
geth();
solve();
}
return 0;
}