P1117 [NOI2016]优秀的拆分

题意

这题正解还是有点难想的。

显然满足条件的字符串就是两个\(AA\)样子的串一前一后。

\(a_i\)表示以\(i\)为开头的\(AA\)串长度,\(a_i\)表示以\(i\)为结尾的\(AA\)串长度,那么答案显然为:
\(\sum\limits_{i=1}^{n-1}a_{i+1}*b_i\)

于是考虑怎么求这个,我们考虑枚举\(AA\)的长度\(len\),对原串每隔\(len\)设一个关键点,因为长为\(len\)\(AA\)串必定过且只过两个关键点,因此对于每对相邻的关键点,我们求出它们随对应的\(AA\)串。

求出\(lcp\)表示\([c_{i},n]\)\([c_{i+1},n]\)\(c_i\)表示第\(i\)个关键点)的最长公共前缀,\(lcs\)表示\([1,c_{i-1}-1]\)\([1,c_{i+1}-1]\)的最长公共后缀。

\(lcs+lcp<len\)时,如下图:

我们发现并不会有\(AA\)串过它们。

\(lcp+lcs\geqslant len\)时:

我们设\(t=lcp-lcs-len+1\)

我们发现在前面的\(t\)长度的点都可以作为一个\(AA\)串的开头,后面\(t\)长度的点都可以作为一个\(AA\)串的结尾。

于是我们要区间加\(1\),这个差分就好了。

code:

#include<bits/stdc++.h>
using namespace std;
typedef long long ll;
const int maxn=30010;
int T,n;
int a[maxn],b[maxn],c[maxn],lg[maxn];
ll ans;
char s[maxn];
struct SA
{
	int n,num;
	int sa[maxn],rk[maxn],oldrk[maxn],id[maxn],tmpid[maxn],cnt[maxn];
	int height[maxn][20];
	char s[maxn];
	inline void clear()
	{
		memset(sa,0,sizeof(sa));
		memset(rk,0,sizeof(rk));
		memset(height,0x3f,sizeof(height));
		memset(s,0,sizeof(s));//一定要清空字符串。
	}
	inline bool cmp(int x,int y,int k){return oldrk[x]==oldrk[y]&&oldrk[x+k]==oldrk[y+k];}
	inline void build()
	{
		num=300;
		memset(cnt,0,sizeof(cnt));
		for(int i=1;i<=n;i++)cnt[rk[i]=s[i]]++;
		for(int i=1;i<=num;i++)cnt[i]+=cnt[i-1];
		for(int i=n;i;i--)sa[cnt[rk[i]]--]=i;
		for(int t=1;t<=n;t<<=1)
		{
			int tot=0;
			for(int i=n-t+1;i<=n;i++)id[++tot]=i;
			for(int i=1;i<=n;i++)if(sa[i]>t)id[++tot]=sa[i]-t;
			tot=0;
			memset(cnt,0,sizeof(cnt));
			for(int i=1;i<=n;i++)cnt[tmpid[i]=rk[id[i]]]++;
			for(int i=1;i<=num;i++)cnt[i]+=cnt[i-1];
			for(int i=n;i;i--)sa[cnt[tmpid[i]]--]=id[i];
			memcpy(oldrk,rk,sizeof(rk));
			for(int i=1;i<=n;i++)rk[sa[i]]=cmp(sa[i-1],sa[i],t)?tot:++tot;
			num=tot;
			if(num>=n)break;
		}
		for(int i=1,j=0;i<=n;i++)
		{
			if(j)j--;
			while(s[i+j]==s[sa[rk[i]-1]+j])j++;
			height[rk[i]][0]=j;
		}
		for(int j=1;j<=18;j++)
			for(int i=1;i+(1<<j)-1<=n;i++)
				height[i][j]=min(height[i][j-1],height[i+(1<<(j-1))][j-1]);
	}
	inline int query(int x,int y)
	{
		x=rk[x],y=rk[y];
		if(x>y)swap(x,y);x++;
		int t=lg[y-x+1];
		return min(height[x][t],height[y-(1<<t)+1][t]);
	}
}Sa[2];
inline void init()
{
	memset(a,0,sizeof(a));
	memset(b,0,sizeof(b));
	Sa[0].clear(),Sa[1].clear();
	ans=0;
}
inline void solve()
{
	scanf("%s",s+1);n=strlen(s+1);
	Sa[0].n=Sa[1].n=n;
	for(int i=1;i<=n;i++)Sa[0].s[i]=Sa[1].s[n-i+1]=s[i];
	Sa[0].build(),Sa[1].build();
	for(int len=1;len<=n/2;len++)
	{
		int tot=0;
		for(int i=len;i<=n;i+=len)c[++tot]=i;
		for(int i=1;i<tot;i++)
		{
			int lcp=min(Sa[0].query(c[i],c[i+1]),len),lcs=min(Sa[1].query(n-c[i]+2,n-c[i+1]+2),len-1);
			if(lcp+lcs<len)continue;
			int t=lcp+lcs-len+1;
			a[c[i]-lcs]++,a[c[i]-lcs+t]--;
			b[c[i+1]+lcp-t]++,b[c[i+1]+lcp]--;
		}
	}
	for(int i=1;i<=n;i++)a[i]+=a[i-1],b[i]+=b[i-1];
	for(int i=1;i<n;i++)ans+=1ll*a[i+1]*b[i];
	printf("%lld\n",ans);
}
int main()
{
	//freopen("test.in","r",stdin);
	//freopen("test.out","w",stdout);
	lg[0]=-1;
	for(int i=1;i<=30000;i++)lg[i]=lg[i>>1]+1;
	scanf("%d",&T);
	while(T--)
	{
		init();
		solve();
	}
	return 0;
}
posted @ 2019-12-23 10:29  nofind  阅读(216)  评论(0编辑  收藏  举报