bzoj 3998 (后缀自动机)
传送门:
题意:
给你一个长度为$n$的字符串$str$和一个数$K$,现在有两个询问:
1. $op=0$:不同位置的相同子串算作一个,求字典序第$K$小子串
2. $op=1$:不同位置的相同子串算作多个,求字典序第$K$小子串
题目分析:
因为后缀自动机能够包含所有的子串,因此我们考虑在后缀自动机上贪心的跳转。
我们设后缀自动机上第$i$号结点所包含的字符串的个数为$num[i]$,不难发现当前结点$i$的后继们一共会包含$\sum_{k=0}^{26}\sum_{st=i}^{next[st][k]==0} num[st]$个字符。我们设它为$siz[i]$。
当我们在后缀自动机上进行跳转时,如果当前位于的结点$i$的$num[st]+siz[i]<K$,则证明结点$i$即其后继都不能成为第$K$小,因此我们直接跳过;而倘若$num[st]+siz[z]\ge K$,则证明在以结点$i$即其后继的某个子串能够作为答案,则我们直接递归结点$i$的后继即可。
而又因为本题有两种操作,故需要分两种情况讨论。
对于$op=1$而言,因为不同位置子串算多个,则这个就等价于在每一个结点的对整体的贡献为$|endpos(i)|$
而对于$op=0$的情况,这就等价于在同一个$endpos$下,答案只能贡献$1$,即对于每一个结点$i$而言,$num[i]=1$
最后我们只需要用拓扑排序求一下$endpos$的大小,以及用后缀和维护一下$siz$,最后统计答案即可。
(ps:对于op=0的情况,这题就等价于spoj SUBLEXspoj~SUBLEXspoj SUBLEX。)
代码:
#include <bits/stdc++.h>
#define maxn 1000005
using namespace std;
char str[maxn];
struct SAM {
int next[maxn * 2][26], fa[maxn * 2], len[maxn * 2];
int last, cnt;
int cntA[maxn * 2], A[maxn * 2];
int num[maxn * 2], siz[maxn * 2];
void clear() {
last = cnt = 1;
fa[1] = len[1] = 0;
memset(next[1], 0, sizeof(next[1]));
}
void init(char *s) {
while (*s) {
Insert(*s - 'a');
s++;
}
}
void Insert(int c) {
int p = last;
int np = ++cnt;
memset(next[cnt], 0, sizeof(next[cnt]));
len[np] = len[p] + 1;
last = np;
while (p && !next[p][c]) next[p][c] = np, p = fa[p];
if (!p)
fa[np] = 1;
else {
int q = next[p][c];
if (len[q] == len[p] + 1)
fa[np] = q;
else {
int nq = ++cnt;
len[nq] = len[p] + 1;
memcpy(next[nq], next[q], sizeof(next[q]));
fa[nq] = fa[q];
fa[np] = fa[q] = nq;
while (next[p][c] == q) next[p][c] = nq, p = fa[p];
}
}
}
void build(int op) {
memset(cntA, 0, sizeof(cntA));
memset(num, 0, sizeof(num));
for (int i = 1; i <= cnt; i++) cntA[len[i]]++;
for (int i = 1; i <= cnt; i++) cntA[i] += cntA[i - 1];
for (int i = cnt; i >= 1; i--) A[cntA[len[i]]--] = i;
int tmp = 1;
int n = strlen(str);
for (int i = 0; i < n; i++) num[tmp = next[tmp][str[i] - 'a']] = 1;
for (int i = cnt; i >= 1; i--) {
int x = A[i];
if (op == 0) {
num[x] = 1;
} else {
num[fa[x]] += num[x];
}
}
num[1] = 0;
for (int i = cnt; i >= 1; i--) {
int x = A[i];
siz[x] = num[x];
for (int j = 0; j < 26; j++) {
siz[x] += siz[next[x][j]];
}
}
}
void dfs(int x, int k) {
if (k <= num[x])
return;
k -= num[x];
for (int i = 0; i < 26; i++) {
int tmp = next[x][i];
if (k <= siz[tmp]) {
putchar(i + 'a');
dfs(tmp, k);
return;
}
k -= siz[tmp];
}
}
} sam;
int main() {
scanf("%s", str);
int op, K;
scanf("%d%d", &op, &K);
sam.clear();
sam.init(str);
sam.build(op);
if (K > sam.siz[1])
puts("-1");
else
sam.dfs(1, K);
return 0;
}