洛谷P3975 弦论
题意:求一个串的字典序第k小的子串/本质不同第k小的子串。
解:一开始我的想法是在后缀树上找,但是不知道后缀树上的边对应的是哪些字符...
然而可以不用fail树转移,用转移边转移即可。
先建一个后缀自动机,记忆化搜索每个节点向后向后有多少个串。
然后从起点开始向后一个字符一个字符的确定。
注意每到一个新点就要判断是否结束,并把k减去在此结束的串的个数。
1 #include <cstdio> 2 #include <cstring> 3 4 const int N = 1000010; 5 6 int tr[N][26], len[N], fail[N], siz[N], bin[N], cnt[N], topo[N]; 7 int top, last; 8 9 inline void init() { 10 top = last = 1; 11 return; 12 } 13 14 inline void insert(char c) { 15 int f = c - 'a'; 16 int p = last, np = ++top; 17 last = np; 18 len[np] = len[p] + 1; 19 cnt[np] = 1; 20 while(p && !tr[p][f]) { 21 tr[p][f] = np; 22 p = fail[p]; 23 } 24 if(!p) { 25 fail[np] = 1; 26 } 27 else { 28 int Q = tr[p][f]; 29 if(len[Q] == len[p] + 1) { 30 fail[np] = Q; 31 } 32 else { 33 int nQ = ++top; 34 len[nQ] = len[p] + 1; 35 fail[nQ] = fail[Q]; 36 fail[Q] = fail[np] = nQ; 37 memcpy(tr[nQ], tr[Q], sizeof(tr[Q])); 38 while(tr[p][f] == Q) { 39 tr[p][f] = nQ; 40 p = fail[p]; 41 } 42 } 43 } 44 return; 45 } 46 47 inline void sort() { 48 for(int i = 1; i <= top; i++) { 49 bin[len[i]]++; 50 } 51 for(int i = 1; i <= top; i++) { 52 bin[i] += bin[i - 1]; 53 } 54 for(int i = 1; i <= top; i++) { 55 topo[bin[len[i]]--] = i; 56 } 57 return; 58 } 59 60 inline void count() { 61 for(int i = top; i >= 1; i--) { 62 int x = topo[i]; 63 cnt[fail[x]] += cnt[x]; 64 } 65 return; 66 } 67 68 char s[N]; 69 70 int DFS(int x) { 71 if(siz[x]) { 72 return siz[x]; 73 } 74 siz[x] = cnt[x]; 75 for(int i = 0; i < 26; i++) { 76 if(tr[x][i]) { 77 siz[x] += DFS(tr[x][i]); 78 } 79 } 80 return siz[x]; 81 } 82 83 int main() { 84 scanf("%s", s + 1); 85 int n = strlen(s + 1); 86 init(); 87 for(int i = 1; i <= n; i++) { 88 insert(s[i]); 89 } 90 sort(); 91 int flag, k; 92 scanf("%d%d", &flag, &k); 93 if(flag) { 94 count(); 95 } 96 else { 97 for(int i = 2; i <= top; i++) { 98 cnt[i] = 1; 99 } 100 } 101 102 int sum = 0; 103 for(int i = 2; i <= top; i++) { 104 sum += cnt[i] * (len[i] - len[fail[i]]); 105 } 106 107 int p = 1; 108 DFS(p); 109 if(k > sum) { 110 puts("-1"); 111 return 0; 112 } 113 cnt[1] = 0; 114 while(1) { 115 //printf("k = %d %d \n", k, k == 1); 116 if(k == 1) { 117 break; 118 } 119 else { 120 k -= cnt[p]; 121 } 122 //printf("k = %d p = %d \n", k, p); 123 for(int i = 0; i < 26; i++) { 124 if(!tr[p][i]) { 125 continue; 126 } 127 if(siz[tr[p][i]] >= k) { 128 putchar(i + 'a'); 129 p = tr[p][i]; 130 break; 131 } 132 else { 133 k -= siz[tr[p][i]]; 134 } 135 } 136 } 137 138 return 0; 139 }
又思考了一下,虽然记忆化搜索会搜到重复的节点,但是这些重复所表示的是到达它的不同方案,也就是它代表的不同子串。所以需要重复统计。
这样一个节点的一条转移边其实就是它下一个字符拼上f之后能形成的子串数。
感觉好神奇...
[update] 刚发现k=1的时候错了……………………稍微改一下就行。
1 #include <cstdio> 2 #include <cstring> 3 4 const int N = 1000010; 5 6 int tr[N][26], len[N], fail[N], siz[N], bin[N], cnt[N], topo[N]; 7 int top, last; 8 9 inline void init() { 10 top = last = 1; 11 return; 12 } 13 14 inline void insert(char c) { 15 int f = c - 'a'; 16 int p = last, np = ++top; 17 last = np; 18 len[np] = len[p] + 1; 19 cnt[np] = 1; 20 while(p && !tr[p][f]) { 21 tr[p][f] = np; 22 p = fail[p]; 23 } 24 if(!p) { 25 fail[np] = 1; 26 } 27 else { 28 int Q = tr[p][f]; 29 if(len[Q] == len[p] + 1) { 30 fail[np] = Q; 31 } 32 else { 33 int nQ = ++top; 34 len[nQ] = len[p] + 1; 35 fail[nQ] = fail[Q]; 36 fail[Q] = fail[np] = nQ; 37 memcpy(tr[nQ], tr[Q], sizeof(tr[Q])); 38 while(tr[p][f] == Q) { 39 tr[p][f] = nQ; 40 p = fail[p]; 41 } 42 } 43 } 44 return; 45 } 46 47 inline void sort() { 48 for(int i = 1; i <= top; i++) { 49 bin[len[i]]++; 50 } 51 for(int i = 1; i <= top; i++) { 52 bin[i] += bin[i - 1]; 53 } 54 for(int i = 1; i <= top; i++) { 55 topo[bin[len[i]]--] = i; 56 } 57 return; 58 } 59 60 inline void count() { 61 for(int i = top; i >= 1; i--) { 62 int x = topo[i]; 63 cnt[fail[x]] += cnt[x]; 64 } 65 return; 66 } 67 68 char s[N]; 69 70 int DFS(int x) { 71 if(siz[x]) { 72 return siz[x]; 73 } 74 siz[x] = cnt[x]; 75 for(int i = 0; i < 26; i++) { 76 if(tr[x][i]) { 77 siz[x] += DFS(tr[x][i]); 78 } 79 } 80 return siz[x]; 81 } 82 83 int main() { 84 scanf("%s", s + 1); 85 int n = strlen(s + 1); 86 init(); 87 for(int i = 1; i <= n; i++) { 88 insert(s[i]); 89 } 90 sort(); 91 int flag, k; 92 scanf("%d%d", &flag, &k); 93 if(flag) { 94 count(); 95 } 96 else { 97 for(int i = 2; i <= top; i++) { 98 cnt[i] = 1; 99 } 100 } 101 102 int sum = 0; 103 for(int i = 2; i <= top; i++) { 104 sum += cnt[i] * (len[i] - len[fail[i]]); 105 } 106 107 int p = 1; 108 DFS(p); 109 if(k > sum) { 110 puts("-1"); 111 return 0; 112 } 113 ++k; 114 cnt[1] = 1; 115 while(1) { 116 //printf("k = %d %d \n", k, k == 1); 117 if(k == 1) { 118 break; 119 } 120 else { 121 k -= cnt[p]; 122 } 123 //printf("k = %d p = %d \n", k, p); 124 for(int i = 0; i < 26; i++) { 125 if(!tr[p][i]) { 126 continue; 127 } 128 if(siz[tr[p][i]] >= k) { 129 putchar(i + 'a'); 130 p = tr[p][i]; 131 break; 132 } 133 else { 134 k -= siz[tr[p][i]]; 135 } 136 } 137 } 138 139 return 0; 140 }