bzoj4566: [Haoi2016]找相同字符

把两个字符串拼起来,然后做后缀数组和LCP,然后开两个单调栈,进行计算即可。

#include<bits/stdc++.h>
using namespace std;
long long len1,len2,n,rank[500000],sa[500000],temp[500000],str[500000],cnt[500000],p[500000];
long long q1[500000],q2[500000],height[500005],sum=0,ans1=0,ans2=0,tail1=0,tail2=0;
long long w1[500000],w2[500000],ans=0;
char s1[200010],s2[200005],s[400050];
bool equ(long long x,long long y,long long l){return rank[x]==rank[y]&&rank[x+l]==rank[y+l];}
void doubling()
{
	for(long long i=0;i<n;i++)str[i]=s[i]-'a';
	str[len1]=26;
	str[n]=rank[n]=-1;
	for(long long i=0;i<n;i++)rank[i]=str[i],sa[i]=i;
	for(long long i,l=0,pos=0,sig=26;pos<n-1;sig=pos)
	{
		for(i=n-l,pos=0;i<n;i++)p[pos++]=i;
		for(i=0;i<n;i++)if(sa[i]>=l)p[pos++]=sa[i]-l;
		memset(cnt,0,sizeof(cnt));
		for(i=0;i<n;i++)cnt[rank[p[i]]]++;
		for(i=1;i<=sig;i++)cnt[i]+=cnt[i-1];
		for(i=n-1;i>=0;i--)sa[--cnt[rank[p[i]]]]=p[i];
		for(temp[sa[0]]=pos=0,i=1;i<n;i++)
		{
			if(!equ(sa[i],sa[i-1],l))pos++;
			temp[sa[i]]=pos;
		}
		for(i=0;i<n;i++)rank[i]=temp[i];
		if(!l)l=1;else l<<=1;
	}
	long long i,k=0;
	for(i=k=0;i<n;i++)
	{
		if(k)k--;
		if(rank[i]==0)continue;
		for(long long j=sa[rank[i]-1];str[i+k]==str[j+k];)
			k++;
		height[rank[i]]=k;
	}
}
void init()
{
	scanf("%s",s1);
	scanf("%s",s2);
	len1=strlen(s1);len2=strlen(s2);
	for(long long i=0;i<len1;i++)s[i]=s1[i];
	s[len1]='$';
	for(long long i=len1+1;i<len1+len2+1;i++)s[i]=s2[i-len1-1];
	n=len1+len2+1;
}
void work()
{
	for(long long i=0;i<n;i++)
	{
		if(q1[tail1])
		{
		ans=0;
		while(q1[tail1]>=height[i]&&tail1)
		{
			ans1-=w1[tail1]*q1[tail1];
			ans+=w1[tail1];
			tail1--;
		}
		q1[++tail1]=height[i];
		w1[tail1]=ans;
		ans1+=ans*height[i];
		}
		if(q2[tail2])
		{
		ans=0;
		while(q2[tail2]>=height[i]&&tail2)
		{
			ans2-=w2[tail2]*q2[tail2];
			ans+=w2[tail2];
			tail2--;
		}
		q2[++tail2]=height[i];
		w2[tail2]=ans;
		ans2+=ans*height[i];
		}
		if(sa[i]<len1)
		{
			sum+=ans2;
			if(q1[tail1]==n-sa[i])
			{
				w1[tail1]++;
				ans1+=n-sa[i];
			}
			else 
			{
				q1[++tail1]=n-sa[i];
				w1[tail1]=1;
				ans1+=n-sa[i];
			}
		}
		else 
		{
			sum+=ans1;
			if(q2[tail2]==n-sa[i])
			{
				w2[tail2]++;
				ans2+=n-sa[i];
			}
			else 
			{
				q2[++tail2]=n-sa[i];
				w2[tail2]=1;
				ans2+=n-sa[i];
			}
		}
	}
	printf("%lld\n",sum);
}
int main()
{
	//freopen("xf.in","r",stdin);
	//freopen("xf.out","w",stdout);
	init();
	doubling();
	work();
	return 0;
}

  

posted @ 2018-03-30 21:19  mybing  阅读(172)  评论(0编辑  收藏  举报