牛客网多校训练第一场 I - Substring(后缀数组 + 重复处理)

链接:

https://www.nowcoder.com/acm/contest/139/I

 

题意:

给出一个n(1≤n≤5e4)个字符的字符串s(si ∈ {a,b,c}),
求最多可以从n*(n+1)/2个子串中选出多少个子串,使得它们互不同构。
同构是指存在一个映射f,使得字符串a的每个字符都可以映射成字符串b的对应字符。
例如ab与ac、ba、bc、ca、cb都是同构的。

 

分析:

以字符串abba为例:
现在只考虑这个字符串的2个子串ab和ba,如果不考虑重构,有2个子串,否则,只有1个子串。
这时,我们可以用全排列枚举出所有重构的字符串:
abba
acca
baab
bccb
caac
cbbc
由于每一个串都有2个子串,所以上面的6个同构串共有12个子串。
如果去掉重复的子串,则最终会剩下6个互不相同的子串。
即第一个字符串abba的ab被第三个字符串baab的ab消掉了,
第二个字符串acca的ac被第五个字符串caac的ac消掉了......
可以发现,剩下的6个子串正是ab的6种同构。
所以我们可以把一个字符串的六种同构拼接在一起,然后用后缀数组求出重复的子串个数height。
为了避免拼接的首尾字符对结果产生影响,要在拼接的每一段后面每次都加上一个新的字符。
设6个同构串的所有子串个数(6*(n*(n+1)/2))为sum。
则(sum-height)/6就是一个字符串里互不重构的子串个数。

但还有一个特殊情况:
只考虑字符串aaabbb的两个子串aaa和bbb。
如果采取上面的做法,最终会留下3个互不相同的子串aaa、bbb和ccc,即重复的子串个数为9。
这时答案是(12-9)/6=0,很显然这样是错误的。
原因是aaa的同构子串只有3种而不是6种,即单一字符的字符串的每个同构串都被多减了一次。
这时,我们可以找出一个字符串里最长的单一字符的字符串str,设它的长度为most。
因为比str短的单一字符的字符串都是str的一部分的重构,所以不需要考虑。
则正确的答案应该是(sum - height + 3*most)/6。(注意例子里aaa的长度视为1而不是3)

 

代码:

 1 #include <cstdio>
 2 #include <algorithm>
 3 using namespace std;
 4 
 5 const int MAXS = 1e6 + 5;
 6 int sa[MAXS], mem[MAXS], mem2[MAXS], amt[MAXS]; // sa:后缀数组
 7 void build_sa(char* s, int n, int m) { // n:字符串s的长度,每个字符值须小于m
 8     mem[n] = mem2[n] = -1;
 9     int i, *x = mem, *y = mem2;
10     for(i = 0; i < m; i++) amt[i] = 0;
11     for(i = 0; i < n; i++) amt[x[i]=s[i]]++;
12     for(i = 1; i < m; i++) amt[i] += amt[i-1];
13     for(i = n-1; i >= 0; i--) sa[--amt[x[i]]] = i;
14     for(int k = 1; k <= n; k <<= 1) {
15         int p = 0;
16         for(i = n-k; i < n; i++) y[p++] = i;
17         for(i = 0; i < n; i++) if(sa[i] >= k) y[p++] = sa[i]-k;
18         for(i = 0; i < m; i++) amt[i] = 0;
19         for(i = 0; i < n; i++) amt[x[y[i]]]++;
20         for(i = 1; i < m; i++) amt[i] += amt[i-1];
21         for(i = n-1; i >= 0; i--) sa[--amt[x[y[i]]]] = y[i];
22         int* t = x;  x = y;  y = t;
23         p = 1; x[sa[0]] = 0;
24         for(i = 1; i < n; i++)
25             x[sa[i]] = y[sa[i-1]]==y[sa[i]]&&y[sa[i-1]+k]==y[sa[i]+k]?p-1:p++;
26         if(p >= n) break;
27         m = p;
28     }
29 }
30 int idx[MAXS], height[MAXS]; // height:sa[i-1]与sa[i]的最长公共前缀
31 void get_height(char* s, int n) { // n:字符串s的长度
32     for(int i = 0; i < n; i++) idx[sa[i]] = i;
33     for(int k = 0, i = 0; i < n; i++) {
34         if(idx[i] - 1 < 0) continue;
35         if(k) k--;
36         int j = sa[idx[i]-1];
37         while(s[i+k] == s[j+k]) k++;
38         height[idx[i]] = k;
39     }
40 }
41 
42 char s[MAXS], os[MAXS];
43 
44 int main() {
45     int n;
46     while(~scanf("%d%s", &n, os)) {
47         int p = 0, en = 4, a[3] = {1, 2, 3};
48         do {
49             for(int i = 0; i < n; i++) s[p++] = a[os[i]-'a'];
50             s[p++] = en++;
51         } while(next_permutation(a, a+3));
52         build_sa(s, p, 10);
53         get_height(s, p);
54         long long ans = 6LL * n*(n+1)/2;
55         for(int i = 1; i < p; i++) ans -= height[i];
56         int most = 1, len = 1;
57         for(int i = 1; i <= n; i++) {
58             if(os[i] == os[i-1]) len++;
59             else most = max(most, len), len = 1;
60         }
61         printf("%lld\n", (ans + 3*most) / 6);
62     }
63     return 0;
64 }

 

posted @ 2018-08-08 23:55  Ctfes  阅读(186)  评论(0编辑  收藏  举报