牛客网多校训练第一场 I - Substring(后缀数组 + 重复处理)
链接:
https://www.nowcoder.com/acm/contest/139/I
题意:
给出一个n(1≤n≤5e4)个字符的字符串s(si ∈ {a,b,c}),
求最多可以从n*(n+1)/2个子串中选出多少个子串,使得它们互不同构。
同构是指存在一个映射f,使得字符串a的每个字符都可以映射成字符串b的对应字符。
例如ab与ac、ba、bc、ca、cb都是同构的。
分析:
以字符串abba为例:
现在只考虑这个字符串的2个子串ab和ba,如果不考虑重构,有2个子串,否则,只有1个子串。
这时,我们可以用全排列枚举出所有重构的字符串:
abba
acca
baab
bccb
caac
cbbc
由于每一个串都有2个子串,所以上面的6个同构串共有12个子串。
如果去掉重复的子串,则最终会剩下6个互不相同的子串。
即第一个字符串abba的ab被第三个字符串baab的ab消掉了,
第二个字符串acca的ac被第五个字符串caac的ac消掉了......
可以发现,剩下的6个子串正是ab的6种同构。
所以我们可以把一个字符串的六种同构拼接在一起,然后用后缀数组求出重复的子串个数height。
为了避免拼接的首尾字符对结果产生影响,要在拼接的每一段后面每次都加上一个新的字符。
设6个同构串的所有子串个数(6*(n*(n+1)/2))为sum。
则(sum-height)/6就是一个字符串里互不重构的子串个数。
但还有一个特殊情况:
只考虑字符串aaabbb的两个子串aaa和bbb。
如果采取上面的做法,最终会留下3个互不相同的子串aaa、bbb和ccc,即重复的子串个数为9。
这时答案是(12-9)/6=0,很显然这样是错误的。
原因是aaa的同构子串只有3种而不是6种,即单一字符的字符串的每个同构串都被多减了一次。
这时,我们可以找出一个字符串里最长的单一字符的字符串str,设它的长度为most。
因为比str短的单一字符的字符串都是str的一部分的重构,所以不需要考虑。
则正确的答案应该是(sum - height + 3*most)/6。(注意例子里aaa的长度视为1而不是3)
代码:
1 #include <cstdio> 2 #include <algorithm> 3 using namespace std; 4 5 const int MAXS = 1e6 + 5; 6 int sa[MAXS], mem[MAXS], mem2[MAXS], amt[MAXS]; // sa:后缀数组 7 void build_sa(char* s, int n, int m) { // n:字符串s的长度,每个字符值须小于m 8 mem[n] = mem2[n] = -1; 9 int i, *x = mem, *y = mem2; 10 for(i = 0; i < m; i++) amt[i] = 0; 11 for(i = 0; i < n; i++) amt[x[i]=s[i]]++; 12 for(i = 1; i < m; i++) amt[i] += amt[i-1]; 13 for(i = n-1; i >= 0; i--) sa[--amt[x[i]]] = i; 14 for(int k = 1; k <= n; k <<= 1) { 15 int p = 0; 16 for(i = n-k; i < n; i++) y[p++] = i; 17 for(i = 0; i < n; i++) if(sa[i] >= k) y[p++] = sa[i]-k; 18 for(i = 0; i < m; i++) amt[i] = 0; 19 for(i = 0; i < n; i++) amt[x[y[i]]]++; 20 for(i = 1; i < m; i++) amt[i] += amt[i-1]; 21 for(i = n-1; i >= 0; i--) sa[--amt[x[y[i]]]] = y[i]; 22 int* t = x; x = y; y = t; 23 p = 1; x[sa[0]] = 0; 24 for(i = 1; i < n; i++) 25 x[sa[i]] = y[sa[i-1]]==y[sa[i]]&&y[sa[i-1]+k]==y[sa[i]+k]?p-1:p++; 26 if(p >= n) break; 27 m = p; 28 } 29 } 30 int idx[MAXS], height[MAXS]; // height:sa[i-1]与sa[i]的最长公共前缀 31 void get_height(char* s, int n) { // n:字符串s的长度 32 for(int i = 0; i < n; i++) idx[sa[i]] = i; 33 for(int k = 0, i = 0; i < n; i++) { 34 if(idx[i] - 1 < 0) continue; 35 if(k) k--; 36 int j = sa[idx[i]-1]; 37 while(s[i+k] == s[j+k]) k++; 38 height[idx[i]] = k; 39 } 40 } 41 42 char s[MAXS], os[MAXS]; 43 44 int main() { 45 int n; 46 while(~scanf("%d%s", &n, os)) { 47 int p = 0, en = 4, a[3] = {1, 2, 3}; 48 do { 49 for(int i = 0; i < n; i++) s[p++] = a[os[i]-'a']; 50 s[p++] = en++; 51 } while(next_permutation(a, a+3)); 52 build_sa(s, p, 10); 53 get_height(s, p); 54 long long ans = 6LL * n*(n+1)/2; 55 for(int i = 1; i < p; i++) ans -= height[i]; 56 int most = 1, len = 1; 57 for(int i = 1; i <= n; i++) { 58 if(os[i] == os[i-1]) len++; 59 else most = max(most, len), len = 1; 60 } 61 printf("%lld\n", (ans + 3*most) / 6); 62 } 63 return 0; 64 }