[UOJ#219][BZOJ4650][Noi2016]优秀的拆分
[UOJ#219][BZOJ4650][Noi2016]优秀的拆分
试题描述
如果一个字符串可以被拆分为 AABBAABB 的形式,其中 A 和 B 是任意非空字符串,则我们称该字符串的这种拆分是优秀的。例如,对于字符串 aabaabaa,如果令 A=aab,B=a,我们就找到了这个字符串拆分成 AABBAABB 的一种方式。一个字符串可能没有优秀的拆分,也可能存在不止一种优秀的拆分。比如我们令 A=a,B=baa,也可以用 AABBAABB 表示出上述字符串;但是,字符串 abaabaa 就没有优秀的拆分。现在给出一个长度为 n 的字符串 S,我们需要求出,在它所有子串的所有拆分方式中,优秀拆分的总个数。这里的子串是指字符串中连续的一段。以下事项需要注意:出现在不同位置的相同子串,我们认为是不同的子串,它们的优秀拆分均会被记入答案。在一个拆分中,允许出现 A=B。例如 cccc 存在拆分 A=B=c。字符串本身也是它的一个子串。
输入
每个输入文件包含多组数据。输入文件的第一行只有一个整数 T,表示数据的组数。保证 1≤T≤10。接
下来 T 行,每行包含一个仅由英文小写字母构成的字符串 S,意义如题所述。
输出
输出 T 行,每行包含一个整数,表示字符串 S 所有子串的所有拆分中,总共有多少个是优秀的拆分。
输入示例
4 aabbbb cccccc aabaabaabaa bbaabaababaaba
输出示例
3 5 4 7
数据规模及约定
|S|≤30000,1≤T≤10
题解
这题貌似是少有的后缀自动机不能解决而后缀数组能解决的题目之一了。
此题做法:后缀数组 + 调和级数。
我们发现问题其实就是找到所有形如 AA 的串(即前后两半相同的串),把它所在的左、右端点位置上计数器 + 1 即可。
首先正反串都建一个后缀数组,这样方便操作。
接下来,我们不妨枚举 A 的长度 len,然后把字符串按照 len 划分成若干块,然后我们需要处理左端点在每一块中的形如 AA 的子串,假设我们当前处理的位置是 i(见下图)。
注意这张图中,我们当前位置为 i,处理开头在最左边那个块中的形如 AA 的子串。(块与块之间的分隔符是长竖线,以下“形如 AA 的子串”均简称为“AA 子串”)
令 L1 = LCP(i, i+len)(LCP 为最长公共前缀,LCS 为最长公共后缀),那么存在 AA 子串的充分必要条件是 LCS(i+len-1, i-1) > 0 且 LCS(i+len-1, i-1) + L1 >= len(否则红色区域就会有不同的字符,那么显然不可能存在 AA 子串)。那么接下来的问题就好办了,令 L2 = LCS(i+L1-len-1, i+L1-1),不难发现区间 [ i+L1-len-L2, i+L1-len ] 中的位置都是长度为 len 的 AA 子串的左端点(至于右端点在哪,请读者思考)。(然而这里并不用线段树实现区间加,可以直接打离线标记)
#include <iostream> #include <cstdio> #include <cstdlib> #include <cstring> #include <cctype> #include <algorithm> using namespace std; int read() { int x = 0, f = 1; char c = getchar(); while(!isdigit(c)){ if(c == '-') f = -1; c = getchar(); } while(isdigit(c)){ x = x * 10 + c - '0'; c = getchar(); } return x * f; } #define maxn 30010 #define maxlog 15 #define LL long long char Str[maxn]; int n, Log[maxn]; struct SA { char S[maxn]; int n, sa[maxn], Ws[maxn], rank[maxn], height[maxn], mnh[maxlog][maxn]; void init(char* str) { strcpy(S + 1, str); n = strlen(S + 1); return ; } bool cmp(int* a, int p1, int p2, int l) { if(p1 + l > n && p2 + l > n) return a[p1] == a[p2]; if(p1 + l > n || p2 + l > n) return 0; return a[p1] == a[p2] && a[p1+l] == a[p2+l]; } void ssort() { int *x = rank, *y = height, m = 0; memset(Ws, 0, sizeof(Ws)); for(int i = 1; i <= n; i++) Ws[x[i] = S[i]]++, m = max(m, x[i]); for(int i = 1; i <= m; i++) Ws[i] += Ws[i-1]; for(int i = n; i; i--) sa[Ws[x[i]]--] = i; for(int j = 1, pos; pos < n; j <<= 1, m = pos) { pos = 0; for(int i = n - j + 1; i <= n; i++) y[++pos] = i; for(int i = 1; i <= n; i++) if(sa[i] > j) y[++pos] = sa[i] - j; for(int i = 1; i <= m; i++) Ws[i] = 0; for(int i = 1; i <= n; i++) Ws[x[i]]++; for(int i = 1; i <= m; i++) Ws[i] += Ws[i-1]; for(int i = n; i; i--) sa[Ws[x[y[i]]]--] = y[i]; swap(x, y); pos = 1; x[sa[1]] = 1; for(int i = 2; i <= n; i++) x[sa[i]] = cmp(y, sa[i], sa[i-1], j) ? pos : ++pos; } return ; } void calch() { for(int i = 1; i <= n; i++) rank[sa[i]] = i; for(int i = 1, j, k = 0; i <= n; height[rank[i++]] = k) for(k ? k-- : 0, j = sa[rank[i]-1]; S[i+k] == S[j+k]; k++); return ; } void rmq_init() { Log[1] = 0; for(int i = 2; i <= n; i++) Log[i] = Log[i>>1] + 1; for(int i = 1; i <= n; i++) mnh[0][i] = height[i]; for(int j = 1; (1 << j) <= n; j++) for(int i = 1; i + (1 << j) - 1 <= n; i++) mnh[j][i] = min(mnh[j-1][i], mnh[j-1][i+(1<<j-1)]); return ; } int query(int p1, int p2) { if(p1 < 1 || p1 > n || p2 < 1 || p2 > n) return 0; int l = rank[p1], r = rank[p2]; if(l > r) swap(l, r); l++; if(l > r) return n; int t = Log[r-l+1]; return min(mnh[t][l], mnh[t][r-(1<<t)+1]); } void _debug() { for(int i = 1; i <= n; i++) printf("%d%c", sa[i], i < n ? ' ' : '\n'); for(int i = 2; i <= n; i++) printf("%d%c", height[i], i < n ? ' ' : '\n'); return ; } } sol, resol; #define repos(i) n - (i) + 1 int totl[maxn], totr[maxn]; void Addl(int l, int r) { if(l < 1) l = 1; if(r > n) r = n; if(l > r) return ; totl[l]++; totl[r+1]--; return ; } void Addr(int l, int r) { if(l < 1) l = 1; if(r > n) r = n; if(l > r) return ; totr[l]++; totr[r+1]--; return ; } int main() { int T = read(); while(T--) { scanf("%s", Str); n = strlen(Str); Str[n++] = 'A'; Str[n] = '\0'; sol.init(Str); sol.ssort(); sol.calch(); sol.rmq_init(); for(int i = 0; i < (n >> 1); i++) swap(Str[i], Str[n-i-1]); resol.init(Str); resol.ssort(); resol.calch(); resol.rmq_init(); memset(totl, 0, sizeof(totl)); memset(totr, 0, sizeof(totr)); for(int len = 1; len < n; len++) for(int i = len + 1; i + len <= n; i += len) { int l1 = sol.query(i, i + len), l2 = resol.query(repos(i + len - 1), repos(i - 1)); // printf("%d %d L1, L2: %d %d\n", len, i, l1, l2); if(!l2 || l1 + l2 < len) continue; l1 = min(l1, len - 1); int r = i + l1 - 1; l2 = min(resol.query(repos(r), repos(r - len)), l1); // printf("[%d] %d (%d, %d) %d\n", len, r, repos(r), repos(r - len), l2); // printf("addr: %d %d | %d %d\n", r + len - l2, r + len, len, i); Addr(r + len - l2, r + len); Addl(r - len - l2 + 1, r - len + 1); } for(int i = 1; i <= n; i++) totl[i] += totl[i-1], totr[i] += totr[i-1]; // for(int i = 1; i <= n; i++) printf("%d %d\n", totl[i], totr[i]); LL ans = 0; for(int i = 1; i < n; i++) ans += (LL)totr[i] * totl[i+1]; printf("%lld\n", ans); } return 0; }