[NOI2016][洛谷P1117]优秀的拆分(SA)

题面

https://www.luogu.com.cn/problem/P1117

题解

前置知识:后缀数组(SA)及height数组、ST表、差分。

本题要求一个字符串中所有AABB形式的字符串(可重)的个数。

首先考虑简化要求:设f[x]表示以第x位为结尾,有多少个AA形式的子串;g[x]表示以第x位为开头有多少个AA形式的子串。每个AABB的拆分都由一个在某位置i结束的AA和一个从i+1开始的BB拼成,所以答案显然是\(\sum_{i=1}^{n-1} f[i]\,g[i+1]\)

枚举AA型字符串的半长len,然后设置第1位,第len+1位,第2len+1位…为特殊点。一个长度为2len的AA型字符串一定通过恰好两个相邻的特殊点。不妨设这两个点是i,j。

设A在特殊点左边的部分(包括特殊点本身)长为l,那么显然有\(1{\leq}l{\leq}len\)。记\(pre_i\)为前缀\(s[1..i]\),\(suf_i\)为后缀\(s[i..n]\),\(lcs\)表示两个前缀的最长公共后缀长度,\(lcp\)表示两个后缀的最长公共前缀长度,则i,j还必须满足\(lcs(pre_i,pre_j){\geq}l\)以及\(lcp(suf_i,suf_j){\geq}len-l+1\)

所以通过两个相邻特殊点i、j,并且特殊点左边的部分长为l的、半长为len的AA型字符串存在的必要条件是:

\[\begin{cases} l{\geq}\max(1,len+1-lcp(suf_i,suf_j)) \\ l{\leq}\min(len,lcs(pre_i,pre_j)) \end{cases} \]

不难发现这也是充分条件。

所以枚举了len,i,j之后,设\(high=\min(len,lcs(pre_i,pre_j)),low=\max(1,len+1-lcp(suf_i,suf_j))\),如果\(low{\leq}high\),那么每个满足\(low{\leq}l{\leq}high\)的l都对应一个起点为i-l+1、终点为j+len-l的AA型子串,于是就把i-high+1到i-low+1的g值全部++,把j+len-high到j+len-low的f值全部++。这个可以维护差分而做到\(O(1)\)的更新。

前缀的最长公共后缀、后缀的最长公共前缀都可以通过预处理前(后)缀数组,再在height数组上建ST表,做到\(O(1)\)查询。

所以总时间复杂度是调和级数\(O(\sum_{i=1}^{n}{\frac{n}{i}})=O(n \log n)\)

代码

#include<bits/stdc++.h>

using namespace std;

// Shorthand for the `register` storage-class hint placed on loop counters.
// Note: `register` was deprecated in C++11 and removed in C++17.
#define rg register
// Shorthand for `inline`.
#define In inline
// Shorthand for `long long`, the type of the f/g counters and of the answer.
#define ll long long

// Capacity bound shared by the arrays in this file (each is declared with N+5 elements).
const int N = 30000;

/**
 * Reads the next decimal integer from stdin.
 *
 * Every character that is not a digit is skipped; if a '-' is seen while
 * skipping, the result is negated. The first character after the digit run is
 * consumed as well.
 *
 * @return the parsed value, or 0 if the input ends before any digit is found.
 */
inline int read(){
	int s = 0,ww = 1;
	// getchar() returns int; keeping the full value lets EOF be told apart
	// from real characters, so a truncated input cannot spin forever here.
	int ch = getchar();
	while(ch != EOF && (ch < '0' || ch > '9')){if(ch == '-')ww = -1;ch = getchar();}
	while('0' <= ch && ch <= '9'){s = 10 * s + (ch - '0');ch = getchar();}
	return s * ww;
}

// Length of the string currently being processed (set from strlen(s + 1) in main).
int n;
// The current string, stored 1-indexed: main reads it into s + 1, so s[0] stays 0.
char s[N+5];
// f[x] / g[x]: number of AA-shaped substrings ending at x / starting at x.
// calcfg fills them as difference arrays and then prefix-sums them.
ll f[N+5],g[N+5];
// lg[i] = floor(log2(i)) for i >= 1, filled once at the start of main.
int lg[N+5];

// Sparse table for range-minimum queries over a 1-indexed int array.
// Relies on the file-level globals n (array length) and lg (log2 table).
struct ST{
	// minn[i][j] holds the minimum of a[i .. i + 2^j - 1].
	int minn[N+5][16];

	// Builds the table from a[1..n].
	void prepro(int a[]){
		for(int pos = 1;pos <= n;++pos){
			minn[pos][0] = a[pos];
		}
		for(int lev = 1;lev < 16;++lev){
			const int span = 1 << lev;
			const int half = span >> 1;
			// Combine the two half-width windows that make up each full window.
			for(int pos = 1;pos + span - 1 <= n;++pos){
				const int left = minn[pos][lev-1];
				const int right = minn[pos+half][lev-1];
				minn[pos][lev] = left < right ? left : right;
			}
		}
	}

	// Returns the minimum of a[l..r]; callers are expected to pass l <= r.
	// Two windows of width 2^d, one anchored at each end, cover the range.
	int query(int l,int r){
		const int d = lg[r-l+1];
		const int head = minn[l][d];
		const int tail = minn[r+1-(1<<d)][d];
		return head < tail ? head : tail;
	}
};

struct SA{
	int sa[N+5],rk[N+5],temp[N+5],num[N+5],h[N+5];
	int m;	
	void clear(){
		memset(sa,0,sizeof(int)*(n+2));
		memset(rk,0,sizeof(int)*(n+2));
		memset(temp,0,sizeof(int)*(n+2));
	}
	void qsort(){
		memset(num,0,sizeof(int) * (m+1));
		for(rg int i = 1;i <= n;i++)num[rk[i]]++;
		for(rg int i = 2;i <= m;i++)num[i] += num[i-1];
		for(rg int i = n;i >= 1;i--)sa[num[rk[temp[i]]]--] = temp[i];
	}
	ST H;
	void calch(){
		int k = 0;
		for(rg int i = 1;i <= n;i++){
			if(rk[i] == 1)h[1] = k = 0;
			else{
				if(k)k--;
				int j = sa[rk[i]-1];
				while(s[i+k] == s[j+k])k++;
				h[rk[i]] = k;
			}
		}
	}
	void init(){
		clear();
		m = 26;
		for(rg int i = 1;i <= n;i++)temp[i] = i;
		for(rg int i = 1;i <= n;i++)rk[i] = s[i] - 'a' + 1;
		qsort();
		for(rg int d = 1;d <= n;d <<= 1){
			int cnt = 0;
			for(rg int i = n - d + 1;i <= n;i++)temp[++cnt] = i;
			for(rg int i = 1;i <= n;i++)if(sa[i] > d)temp[++cnt] = sa[i] - d;
			qsort();
			memcpy(temp,rk,sizeof(int) * (n+1));
			cnt = 1;
			rk[sa[1]] = 1;
			for(rg int i = 2;i <= n;i++){
				if(temp[sa[i]] != temp[sa[i-1]] || temp[sa[i]+d] != temp[sa[i-1]+d])cnt++;
				rk[sa[i]] = cnt;
			}
			if(cnt == n)break;
			m = cnt;
		}
		calch();
		H.prepro(h);
	}
	int lcp(int i,int j){
		int x = rk[i],y = rk[j];
		if(x > y)swap(x,y);
		return H.query(x + 1,y);
	}
}S;

struct PA{
	int pa[N+5],rk[N+5],temp[N+5],num[N+5],h[N+5];
	int m;
	void clear(){
		memset(pa,0,sizeof(int)*(n+2));
		memset(rk,0,sizeof(int)*(n+2));
		memset(temp,0,sizeof(int)*(n+2));
	}
	void qsort(){
		memset(num,0,sizeof(int) * (m+1));
		for(rg int i = 1;i <= n;i++)num[rk[i]]++;
		for(rg int i = 2;i <= m;i++)num[i] += num[i-1];
		for(rg int i = n;i >= 1;i--)pa[num[rk[temp[i]]]--] = temp[i];
	}
	ST H;
	void calch(){
		int k = 0;
		for(rg int i = n;i >= 1;i--){
			if(rk[i] == 1)h[1] = k = 0;
			else{
				if(k)k--;
				int j = pa[rk[i]-1];
				while(s[i-k] == s[j-k])k++;
				h[rk[i]] = k;
			}
		}
	}
	void init(){
		clear();
		m = 26;
		for(rg int i = 1;i <= n;i++)temp[i] = i;
		for(rg int i = 1;i <= n;i++)rk[i] = s[i] - 'a' + 1;
		qsort();
		for(rg int d = 1;d <= n;d <<= 1){
			int cnt = 0;
			for(rg int i = 1;i <= d;i++)temp[++cnt] = i;
			for(rg int i = 1;i <= n;i++)if(pa[i] + d <= n)temp[++cnt] = pa[i] + d;
			qsort();
			memcpy(temp,rk,sizeof(int) * (n+1));
			cnt = 1;
			rk[pa[1]] = 1;
			for(rg int i = 2;i <= n;i++){
				if(temp[pa[i]] != temp[pa[i-1]] || temp[pa[i]-d] != temp[pa[i-1]-d])cnt++;
				rk[pa[i]] = cnt;
			}
			if(cnt == n)break;
			m = cnt;
		}
		calch();
		H.prepro(h);
	}
	int lcs(int i,int j){
		int x = rk[i],y = rk[j];
		if(x > y)swap(x,y);
		return H.query(x + 1,y);
	}
}P;

// Fills f[] (AA substrings ending at each position) and g[] (AA substrings
// starting at each position) for the current string.
// For every half-length, anchor points are placed every `half` characters;
// each pair of neighbouring anchors contributes one contiguous range of valid
// start/end positions, recorded as a difference-array update.
void calcfg(){
	for(int half = 1;2 * half <= n;++half){
		for(int left = 1;left + half <= n;left += half){
			const int right = left + half;
			const int back = P.lcs(left,right);
			const int fwd = S.lcp(left,right);
			// hi/lo bound how many characters of A may lie at or before the anchor.
			const int hi = back < half ? back : half;
			const int lo = (half + 1 - fwd > 1) ? half + 1 - fwd : 1;
			if(hi < lo)continue;
			++g[left-hi+1];
			--g[left-lo+2];
			++f[right+half-hi];
			--f[right+half-lo+1];
		}
	}
	// Turn the difference arrays into actual counts.
	for(int k = 1;k <= n;++k){
		f[k] += f[k-1];
		g[k] += g[k-1];
	}
}

// Reads T test cases; for each string prints the number of AABB splits over
// all of its substrings, computed as the sum of f[i] * g[i+1].
int main(){
	// The scanf below is capped at 30000 characters. They are written starting
	// at s + 1 and followed by a terminator, so s (N+5 bytes) must hold them.
	static_assert(N + 3 >= 30000,"s must have room for the bounded scanf in main");
	for(int i = 2;i <= N;i++)lg[i] = lg[i>>1] + 1;
	int T = read();
	while(T--){
		// Bounded read: an over-long token can no longer overflow s.
		// Stop if the input ends early instead of reprocessing stale data.
		if(scanf("%30000s",s + 1) != 1)break;
		n = strlen(s + 1);
		S.init();
		P.init();
		calcfg();
		ll ans = 0;
		for(int i = 1;i < n;i++)ans += f[i] * g[i+1];
		cout << ans << '\n';
		// calcfg touches indices up to n+1, so reset 0..n+1 for the next case.
		memset(f,0,sizeof(ll) * (n+2));
		memset(g,0,sizeof(ll) * (n+2));
	}
	return 0;
}
posted @ 2020-10-05 19:00  coder66  阅读(215)  评论(0编辑  收藏  举报