poj 3415 Common Substrings 后缀数组+单调栈

题目大意

求两个串长度>=k的公共子串个数

分析

后缀数组+单调栈
考虑n^2枚举做法的优化
枚举j
再枚举\(<j的i\)
lcp(i,j)就是i-j间最小的height
贡献为height-K+1
j右移一位
左边所有\(>height[newj]\)的点的贡献就要减少了
可以发先这是height值从左往右是单调递增的
用单调栈维护
退栈时时同一height合并
对于B求一次A
对于A求一次B
求的时候就按上面的方法
不会算重算漏

注意

模板里要注意的地方都打注释了

solution

#include <cstdio>
#include <cstdlib>
#include <cstring>
#include <cctype>
#include <cmath>
#include <algorithm>
using namespace std;
const int M=400007;
typedef long long LL;

int tcas=0;
int mid;
int K,n;
char s[M];

int sa[M],t[M];
int rk[M],f[M];
int sum[M],h[M];

void getsa(){
	memset(sum,0,sizeof(sum));//
	int i,j,p,nw=500;
	for(i=1;i<=n;i++) sum[s[i]]++;
	for(i=1;i<=nw;i++) sum[i]+=sum[i-1];
	for(i=n;i>0;i--) sa[sum[s[i]]--]=i;
	for(p=0,i=1;i<=n;i++) rk[sa[i]]=(s[sa[i]]!=s[sa[i-1]])?(++p):(p);
	for(nw=p,j=1;nw!=n;j<<=1,nw=p){
		memset(sum,0,sizeof(sum));
		memcpy(f,rk,sizeof(rk));
		for(p=0,i=n-j+1;i<=n;i++) t[++p]=i;
		for(i=1;i<=n;i++) if(sa[i]>j) t[++p]=sa[i]-j;
		for(i=1;i<=n;i++) sum[f[i]]++;
		for(i=1;i<=nw;i++) sum[i]+=sum[i-1];
		for(i=n;i>0;i--) sa[sum[f[t[i]]]--]=t[i];//=t[i]
		for(p=0,i=1;i<=n;i++) rk[sa[i]]=(f[sa[i]]!=f[sa[i-1]]||f[sa[i]+j]!=f[sa[i-1]+j])?(++p):(p);
	}
}

void geth(){
	int i,j,p=0;
	for(i=1;i<=n;i++){
		j=sa[rk[i]-1];
		for(;i<=n&&j<=n&&s[i+p]==s[j+p];p++);
		h[rk[i]]=p;//rk[i]
		if(p) p--;
	}
	h[1]=0;//
}

struct node{
	LL h,num;
}st[M];
int top,num;
LL cnt;

void solve(){
	LL res=0;
	top=num=0; cnt=0;
	for(int i=1;i<=n;i++){
		if(h[i]<K){
			top=0; cnt=0;
			continue;
		}
		num=0;
		if(sa[i-1]<mid) cnt+=h[i]-K+1,num=1;
		for(;top&&st[top].h>=h[i];top--){//小心越界 
			cnt-=st[top].num*(st[top].h-h[i]);
			num+=st[top].num;
		}
		if(num){
			st[++top].h=h[i];
			st[top].num=num;
		}
		if(sa[i]>mid) res+=cnt;
	}
	for(int i=1;i<=n;i++){
		if(h[i]<K){
			top=0; cnt=0;
			continue;
		}
		num=0;
		if(sa[i-1]>mid) cnt+=h[i]-K+1,num=1;
		for(;top&&st[top].h>=h[i];top--){
			cnt-=st[top].num*(st[top].h-h[i]);
			num+=st[top].num;
		}
		if(num){
			st[++top].h=h[i];
			st[top].num=num;
		}
		if(sa[i]<mid) res+=cnt;
	}
	printf("%lld\n",res);
}

int main(){
	while(1){
		scanf("%d",&K);
		if(K==0) break;
		scanf("%s",s+1);
		n=strlen(s+1);
		s[++n]='+';
		mid=n;
		scanf("%s",s+n+1);
		n+=strlen(s+n+1);
		getsa();
		geth();
		solve();
	}
	return 0;
}
posted @ 2017-02-17 16:26  _zwl  阅读(191)  评论(0编辑  收藏  举报