poj 3415 Common Substrings【SA+单调栈】

把两个串中间加一个未出现字符接起来,然后求SA
然后把贡献统计分为两部分,在排序后的后缀里,属于串2的后缀和排在他前面属于串1的后缀的贡献和属于串1的后缀和排在他前面属于串2的后缀的贡献
两部分分别作,都用单调栈维护一段里的height最小值(因为lcp是排序后两后缀中间height最小值),然后根据每次入栈种类来给答案算贡献

#include<iostream>
#include<cstdio>
#include<cstring>
using namespace std;
const int N=500005;
int n,m,k,wa[N],wb[N],wv[N],wsu[N],sa[N],rk[N],he[N],s[N],v[N],top;
long long tot,ans;
char c[N],t[N];
bool cmp(int r[],int a,int b,int l)
{
	return r[a]==r[b]&&r[a+l]==r[b+l];
}
void saa(char r[],int n,int m)
{
	int *x=wa,*y=wb;
	for(int i=0;i<=m;i++)
		wsu[i]=0;
	for(int i=1;i<=n;i++)
		wsu[x[i]=r[i]]++;
	for(int i=1;i<=m;i++)
		wsu[i]+=wsu[i-1];
	for(int i=n;i>=1;i--)
		sa[wsu[x[i]]--]=i;
	for(int j=1,p=1;j<n&&p<n;j<<=1,m=p)
	{
		p=0;
		for(int i=n-j+1;i<=n;i++)
			y[++p]=i;
		for(int i=1;i<=n;i++)
			if(sa[i]>j)
				y[++p]=sa[i]-j;
		for(int i=1;i<=n;i++)
			wv[i]=x[y[i]];
		for(int i=0;i<=m;i++)
			wsu[i]=0;
		for(int i=1;i<=n;i++)
			wsu[wv[i]]++;
		for(int i=1;i<=m;i++)
			wsu[i]+=wsu[i-1];
		for(int i=n;i>=1;i--)
			sa[wsu[wv[i]]--]=y[i];
		swap(x,y);
		p=1;
		x[sa[1]]=1;
		for(int i=2;i<=n;i++)
			x[sa[i]]=cmp(y,sa[i-1],sa[i],j)?p:++p;
	}
	for(int i=1;i<=n;i++)
		rk[sa[i]]=i;
	for(int i=1,j,k=0;i<=n;he[rk[i++]]=k)
		for(k?k--:0,j=sa[rk[i]-1];r[i+k]==r[j+k];k++);
}
int main()
{
	while(scanf("%d",&k)&&k)
	{
		scanf("%s%s",c+1,t+1);
		n=strlen(c+1),m=strlen(t+1);
		c[n+1]=1;
		for(int i=n+2;i<=n+m+1;i++)
			c[i]=t[i-n-1];
		saa(c,n+m+1,200);
		sa[0]=n+m+2;
		// for(int i=0;i<=n+m+2;i++)
			// cerr<<sa[i]-1<<" ";
		// cerr<<endl;
		ans=0,top=0,tot=0;
		for(int i=1;i<=n+m+1;i++)
		{
			if(he[i]<k)
				top=0,tot=0;
			else
			{
				int con=0;
				if(sa[i-1]<=n)
					con++,tot+=he[i]-k+1;
				while(top>0&&he[i]<=s[top])
					tot-=v[top]*(s[top]-he[i]),con+=v[top--];
				s[++top]=he[i],v[top]=con;
				if(sa[i]>n+1)
					ans+=tot;//cerr<<tot<<endl;
			}
		}//cerr<<ans<<endl;
		top=0,tot=0;
		for(int i=1;i<=n+m+1;i++)
		{
			if(he[i]<k)
				top=0,tot=0;
			else
			{
				int con=0;
				if(sa[i-1]>n+1)
					con++,tot+=he[i]-k+1;
				while(top>0&&he[i]<=s[top])
					tot-=v[top]*(s[top]-he[i]),con+=v[top--];
				s[++top]=he[i],v[top]=con;
				if(sa[i]<=n)
					ans+=tot;
			}
		}
		printf("%lld\n",ans);
	}
	return 0;
}
posted @ 2018-11-21 09:01  lokiii  阅读(97)  评论(0编辑  收藏  举报