后缀自动机入门题集
学了好几天后缀自动机,总算是真正搞懂了,才敢来发博客。
后缀自动机是啥以及怎么构造就不说了,毕竟有很多博客比我讲的好多了。
还是按照国际惯例,推荐几发:
hiho一下 127~132周
后缀自动机入门
史上最通俗的后缀自动机详解
先谈谈我对parent tree的理解:
首先由后缀链接构成的树就叫做parent tree(也叫后缀链接树)。
但是我们存的是反边,所以想dfs的时候就很不舒服,有两种解决办法:
1.倒着存正边(没见人用过)。
2.根据SAM的性质,子节点所代表的最长的字符串的长度一定大于父亲节点的,所以根据len的大小排序,然后从大往小更新即可。排序的时候用桶排序,可以省去一个log。
后缀链接的父亲节点的endpos完全包含子节点的endpos,且父亲节点所代表的字符串是子节点的后缀。
关于每一个节点存的endpos的数量,父亲节点的不一定恰好比子节点多1,而是等于它所有子节点的endpos数量之和;如果父亲节点本身对应原串的一个前缀(即不是复制出来的节点),还要再加上1。
接下来是入门题:
luogu P3804 【模板】后缀自动机
确实是模板,我们建完后缀自动机后,求出每一个节点的endpos的数量,做法就是插入的同时标记一下这个节点是一个endpos。然后递归求一遍子树的endpos之和就是这个节点的endpos个数咯。之后如果siz[i] > 1,就用siz[i] * len[i]更新答案。
#include<cstdio>
#include<iostream>
#include<cmath>
#include<algorithm>
#include<cstring>
#include<cstdlib>
#include<cctype>
#include<vector>
#include<stack>
#include<queue>
using namespace std;

// Output shorthands: print a newline / a single blank.
#define enter puts("")
#define space putchar(' ')
// Fill an entire array with byte value x (a must be a real array: sizeof is used).
#define Mem(a, x) memset(a, x, sizeof(a))
#define In inline

using ll = long long;
using db = double;

constexpr int INF = 0x3f3f3f3f;
constexpr db eps = 1e-8;
// Capacity of the input buffer; the automaton arrays are sized maxn << 1.
constexpr int maxn = 1000005;
// Width of one transition row.
constexpr int maxs = 30;
// Reads the next decimal integer from stdin.
// Skips everything up to the first digit; the value is negated when the
// character immediately before the first digit is '-'.
// Returns 0 when end of input is reached before any digit.
inline ll read()
{
	ll ans = 0;
	// int, not char: EOF must stay distinguishable and isdigit() needs an
	// unsigned-char value or EOF.
	int ch = getchar(), last = ' ';
	while(ch != EOF && !isdigit(ch)) last = ch, ch = getchar();
	while(ch != EOF && isdigit(ch)) ans = ans * 10 + (ch - '0'), ch = getchar();
	return last == '-' ? -ans : ans;
}
// Prints x in decimal to stdout, with a leading '-' for negative values.
inline void write(ll x)
{
	if(x < 0)
	{
		putchar('-');
		x = -x;
	}
	// Collect digits least-significant first, then emit them in reverse.
	char digits[20];
	int top = 0;
	do
	{
		digits[top++] = x % 10 + '0';
		x /= 10;
	} while(x);
	while(top) putchar(digits[--top]);
}
// Input buffer: main() reads the string into it and feeds it to the automaton.
char s[maxn];
// Suffix automaton over letter indices 0..maxs-1. State 0 is the root and
// the other states are numbered 1..cnt in creation order.
// init() does not clear tra/siz/buc: the only instance, S, is a global and
// therefore starts zero-filled; the automaton is built once.
struct Sam
{
	int las, cnt;	// las: state of the whole string so far; cnt: id of the newest state
	int tra[maxn << 1][maxs], link[maxn << 1], len[maxn << 1];
	int siz[maxn << 1];	// 1 for states created as `now`, 0 for clones, until dfs() accumulates
	// Sets up the root: suffix link -1, length 0.
	In void init()
	{
		link[las = 0] = -1, len[cnt = 0] = 0;
	}
	// Appends letter index c to the string.
	In void insert(int c)
	{
		int now = ++cnt;
		len[now] = len[las] + 1;
		int p = las;
		while(p != -1 && !tra[p][c]) tra[p][c] = now, p = link[p];
		if(p == -1) link[now] = 0;
		else
		{
			int q = tra[p][c];
			if(len[q] == len[p] + 1) link[now] = q;
			else
			{
				// q is longer than len[p] + 1: split off a clone of exactly that length
				int clo = ++cnt;
				len[clo] = len[p] + 1;
				memcpy(tra[clo], tra[q], sizeof(tra[q]));
				link[clo] = link[q];
				link[q] = link[now] = clo;
				while(p != -1 && tra[p][c] == q) tra[p][c] = clo, p = link[p];
			}
		}
		siz[las = now] = 1;
	}
	int buc[maxn << 1], pos[maxn << 1];
	// Accumulates siz along suffix links (states visited by decreasing len via
	// a counting sort) and returns the maximum of siz * len over states whose
	// accumulated siz exceeds 1; 0 if there is none.
	// The product is formed in 64 bits: siz and len can each approach the
	// string length, so an int product would overflow.
	In ll dfs()
	{
		ll ret = 0;
		for(int i = 1; i <= cnt; ++i) ++buc[len[i]];
		for(int i = 1; i <= cnt; ++i) buc[i] += buc[i - 1];
		for(int i = 1; i <= cnt; ++i) pos[buc[len[i]]--] = i;
		for(int i = cnt; i; --i)
		{
			int now = pos[i];
			siz[link[now]] += siz[now];
			if(siz[now] > 1) ret = max(ret, 1LL * siz[now] * len[now]);
		}
		return ret;
	}
}S;
// Reads one string, builds its suffix automaton and prints S.dfs().
int main()
{
	scanf("%s", s);
	S.init();
	// Walk the buffer up to its terminating NUL.
	for(char* cur = s; *cur; ++cur) S.insert(*cur - 'a');
	write(S.dfs());
	enter;
	return 0;
}
SP1811 LCS
求两个串的lcs。
把一个串建成后缀自动机,然后另一个串在上面跑,相当于枚举另一个串的前缀,看每一个前缀的后缀最多能和原串匹配多少。每成功匹配一个节点,就用当前匹配的长度更新答案。
#include<cstdio>
#include<iostream>
#include<cmath>
#include<algorithm>
#include<cstring>
#include<cstdlib>
#include<cctype>
#include<vector>
#include<queue>
#include<assert.h>
#include<ctime>
using namespace std;

// Output shorthands: print a newline / a single blank.
#define enter puts("")
#define space putchar(' ')
// Fill an entire array with byte value x (a must be a real array: sizeof is used).
#define Mem(a, x) memset(a, x, sizeof(a))
#define In inline
// Adjacency-list edge loop; expects `head` and `e` to be declared by the user.
#define forE(i, x, y) for(int i = head[x], y; ~i && (y = e[i].to); i = e[i].nxt)

using ll = long long;
using db = double;

constexpr int INF = 0x3f3f3f3f;
constexpr db eps = 1e-8;
// Capacity of the input buffer; the automaton arrays are sized maxn << 1.
constexpr int maxn = 250005;
// Width of one transition row.
constexpr int maxs = 30;
// Reads the next decimal integer from stdin.
// Skips everything up to the first digit; the value is negated when the
// character immediately before the first digit is '-'.
// Returns 0 when end of input is reached before any digit.
In ll read()
{
	ll ans = 0;
	// int, not char: EOF must stay distinguishable and isdigit() needs an
	// unsigned-char value or EOF.
	int ch = getchar(), las = ' ';
	while(ch != EOF && !isdigit(ch)) las = ch, ch = getchar();
	while(ch != EOF && isdigit(ch)) ans = ans * 10 + (ch - '0'), ch = getchar();
	return las == '-' ? -ans : ans;
}
// Prints x in decimal to stdout, with a leading '-' for negative values.
In void write(ll x)
{
	if(x < 0)
	{
		putchar('-');
		x = -x;
	}
	// Collect digits least-significant first, then emit them in reverse.
	char digits[20];
	int top = 0;
	do
	{
		digits[top++] = x % 10 + '0';
		x /= 10;
	} while(x);
	while(top) putchar(digits[--top]);
}
// Redirects stdin to ".in" and stdout to ".out"; compiles to a no-op when
// the macro `mrclr` is defined.
In void MYFILE()
{
#ifdef mrclr
	return;
#else
	freopen(".in", "r", stdin);
	freopen(".out", "w", stdout);
#endif
}
// Length of the first input string; assigned in main().
int n;
// Input buffer; main() reads both strings into it, one after the other.
char s[maxn];
// Suffix automaton over letter indices 0..maxs-1. State 0 is the root and
// the other states are numbered 1..cnt in creation order.
struct Sam
{
	int tra[maxn << 1][maxs], len[maxn << 1], link[maxn << 1], cnt, las;
	// Makes state 0 the only state, with suffix link -1.
	In void init() {link[cnt = las = 0] = -1;}
	// Appends letter index c to the string.
	In void insert(int c)
	{
		int now = ++cnt, p = las; Mem(tra[now], 0);
		len[now] = len[las] + 1;
		while(~p && !tra[p][c]) tra[p][c] = now, p = link[p];
		if(p == -1) link[now] = 0;
		else
		{
			int q = tra[p][c];
			if(len[q] == len[p] + 1) link[now] = q;
			else
			{
				// q is longer than len[p] + 1: split off a clone of exactly that length
				int clo = ++cnt;
				len[clo] = len[p] + 1;
				memcpy(tra[clo], tra[q], sizeof(tra[q]));
				link[clo] = link[q], link[q] = link[now] = clo;
				while(~p && tra[p][c] == q) tra[p][c] = clo, p = link[p];
			}
		}
		las = now;
	}
	// Feeds s through the automaton and returns the largest matched length
	// seen, i.e. the length of the longest substring of s that the automaton
	// also accepts as a substring.
	In int solve(char* s)
	{
		int ret = 0, n = strlen(s);
		for(int i = 0, p = 0, l = 0; i < n; ++i)
		{
			int c = s[i] - 'a';
			// Drop to shorter suffixes until one can be extended by c.
			while(p != -1 && !tra[p][c])
			{
				p = link[p];
				// Only read len[] for a real state; p may have fallen off the root.
				if(p != -1) l = len[p];
			}
			if(p == -1) p = l = 0;
			else p = tra[p][c], ++l;
			ret = max(ret, l);
		}
		return ret;
	}
}S;
// Builds the automaton of the first string, then prints the result of
// matching the second string against it.
int main()
{
	// MYFILE();
	scanf("%s", s);
	n = strlen(s);
	S.init();
	for(int i = 0; i != n; ++i) S.insert(s[i] - 'a');
	// The buffer is reused for the second string.
	scanf("%s", s);
	const int best = S.solve(s);
	write(best);
	enter;
	return 0;
}
SP1812 LCS2
求多个串的lcs。
还是先把一个串建成后缀自动机。
然后对于每一个串,都放在后缀自动机上跑,记录在每一个节点能匹配的最大长度。然后这些长度取min,就是所有串在每一个节点能匹配的最大长度。最后答案遍历每一个节点取max即可。
不过还得想一想的是,如果这个节点成功匹配了,那么他的所有祖先节点显然也是匹配了的,但是却只标记了这个节点。所以在每一个串跑完后缀自动机后,要从叶子节点开始把标记再往上传一遍(传给祖先时匹配长度要和祖先的len取min),更新所有祖先节点的匹配情况。
#include<cstdio>
#include<iostream>
#include<cmath>
#include<algorithm>
#include<cstring>
#include<cstdlib>
#include<cctype>
#include<vector>
#include<stack>
#include<queue>
using namespace std;

// Output shorthands: print a newline / a single blank.
#define enter puts("")
#define space putchar(' ')
// Fill an entire array with byte value x (a must be a real array: sizeof is used).
#define Mem(a, x) memset(a, x, sizeof(a))
#define In inline

using ll = long long;
using db = double;

constexpr int INF = 0x3f3f3f3f;
constexpr db eps = 1e-8;
// Capacity of the input buffer; the automaton arrays are sized maxn << 1.
constexpr int maxn = 100005;
// Reads the next decimal integer from stdin.
// Skips everything up to the first digit; the value is negated when the
// character immediately before the first digit is '-'.
// Returns 0 when end of input is reached before any digit.
inline ll read()
{
	ll ans = 0;
	// int, not char: EOF must stay distinguishable and isdigit() needs an
	// unsigned-char value or EOF.
	int ch = getchar(), last = ' ';
	while(ch != EOF && !isdigit(ch)) last = ch, ch = getchar();
	while(ch != EOF && isdigit(ch)) ans = ans * 10 + (ch - '0'), ch = getchar();
	return last == '-' ? -ans : ans;
}
// Prints x in decimal to stdout, with a leading '-' for negative values.
inline void write(ll x)
{
	if(x < 0)
	{
		putchar('-');
		x = -x;
	}
	// Collect digits least-significant first, then emit them in reverse.
	char digits[20];
	int top = 0;
	do
	{
		digits[top++] = x % 10 + '0';
		x /= 10;
	} while(x);
	while(top) putchar(digits[--top]);
}
// Input buffer; main() reads every input string into it in turn.
char s[maxn];
// Suffix automaton over letter indices 0..29. State 0 is the root and the
// other states are numbered 1..cnt in creation order.
// The arrays are never cleared here; the only instance, S, is a global and
// starts zero-filled (main() initialises Min/Max explicitly).
struct Sam
{
	int las, cnt;
	int tra[maxn << 1][30], len[maxn << 1], link[maxn << 1];
	// Makes state 0 the only state, with suffix link -1.
	In void init() {link[las = cnt = 0] = -1;}
	// Appends letter index c to the string.
	In void insert(int c)
	{
		int now = ++cnt, p = las;
		len[now] = len[las] + 1;
		while(~p && !tra[p][c]) tra[p][c] = now, p = link[p];
		if(p == -1) link[now] = 0;
		else
		{
			int q = tra[p][c];
			if(len[q] == len[p] + 1) link[now] = q;
			else
			{
				// q is longer than len[p] + 1: split off a clone of exactly that length
				int clo = ++cnt;
				memcpy(tra[clo], tra[q], sizeof(tra[q]));
				len[clo] = len[p] + 1;
				link[clo] = link[q]; link[q] = link[now] = clo;
				while(~p && tra[p][c] == q) tra[p][c] = clo, p = link[p];
			}
		}
		las = now;
	}
	int buc[maxn << 1], pos[maxn << 1];
	// Counting sort of states 1..cnt by len: pos[1..cnt] lists them in
	// non-decreasing len order. Call once after the last insert().
	In void sort()
	{
		for(int i = 1; i <= cnt; ++i) ++buc[len[i]];
		for(int i = 1; i <= cnt; ++i) buc[i] += buc[i - 1];
		for(int i = 1; i <= cnt; ++i) pos[buc[len[i]]--] = i;
	}
	// Max[v]: longest match ending in state v for the string being processed.
	// Min[v]: minimum of those values over all strings processed so far.
	int Max[maxn << 1], Min[maxn << 1];
	// Matches s against the automaton and folds the per-state match lengths
	// into Min. Requires sort() to have been called.
	In void lcs(char* s)
	{
		int n = strlen(s);
		for(int i = 0, p = 0, l = 0; i < n; ++i)
		{
			int c = s[i] - 'a';
			// Drop to shorter suffixes until one can be extended by c.
			while(p != -1 && !tra[p][c])
			{
				p = link[p];
				// Only read len[] for a real state; p may have fallen off the root.
				if(p != -1) l = len[p];
			}
			if(p == -1) p = l = 0;
			else ++l, p = tra[p][c], Max[p] = max(Max[p], l);
		}
		// Longest states first: a match in a state also matches its suffix-link
		// parent, capped by the parent's own len.
		for(int i = cnt; i; --i)
		{
			int now = pos[i], fa = link[now];
			Max[fa] = max(Max[fa], min(Max[now], len[fa]));
			Min[now] = min(Min[now], Max[now]); Max[now] = 0;
		}
	}
}S;
// Builds the automaton of the first string, matches every following string
// against it and prints the largest per-state minimum match length.
// When no further string follows, the first string's own length is printed
// (Min would otherwise still hold its 0x3f fill value).
int main()
{
	//freopen("ha.in", "r", stdin);
	scanf("%s", s);
	int n = strlen(s); S.init();
	for(int i = 0; i < n; ++i) S.insert(s[i] - 'a');
	S.sort();
	Mem(S.Min, 0x3f); Mem(S.Max, 0);
	bool matched = false;
	while(scanf("%s", s) == 1) S.lcs(s), matched = true;
	int ans = 0;
	if(!matched) ans = n;
	else for(int i = 1; i <= S.cnt; ++i) ans = max(ans, S.Min[i]);
	write(ans), enter;
	return 0;
}
[USACO06DEC]Milk Patterns
找出现了至少\(k\)次的最长的子串。
建完后缀自动机,求一遍子树大小,然后如果大于等于\(k\)就用该节点的len更新答案好了。
#include<cstdio>
#include<iostream>
#include<cmath>
#include<algorithm>
#include<cstring>
#include<cstdlib>
#include<cctype>
#include<vector>
#include<stack>
#include<queue>
#include<map>
using namespace std;

// Output shorthands: print a newline / a single blank.
#define enter puts("")
#define space putchar(' ')
// Fill an entire array with byte value x (a must be a real array: sizeof is used).
#define Mem(a, x) memset(a, x, sizeof(a))
#define In inline

using ll = long long;
using db = double;

constexpr int INF = 0x3f3f3f3f;
constexpr db eps = 1e-8;
// The automaton arrays are sized maxn << 1.
constexpr int maxn = 20005;
// Reads the next decimal integer from stdin.
// Skips everything up to the first digit; the value is negated when the
// character immediately before the first digit is '-'.
// Returns 0 when end of input is reached before any digit.
inline ll read()
{
	ll ans = 0;
	// int, not char: EOF must stay distinguishable and isdigit() needs an
	// unsigned-char value or EOF.
	int ch = getchar(), last = ' ';
	while(ch != EOF && !isdigit(ch)) last = ch, ch = getchar();
	while(ch != EOF && isdigit(ch)) ans = ans * 10 + (ch - '0'), ch = getchar();
	return last == '-' ? -ans : ans;
}
// Prints x in decimal to stdout, with a leading '-' for negative values.
inline void write(ll x)
{
	if(x < 0)
	{
		putchar('-');
		x = -x;
	}
	// Collect digits least-significant first, then emit them in reverse.
	char digits[20];
	int top = 0;
	do
	{
		digits[top++] = x % 10 + '0';
		x /= 10;
	} while(x);
	while(top) putchar(digits[--top]);
}
// n: how many values main() reads and inserts; K: occurrence threshold tested in Sam::dfs().
int n, K;
// Suffix automaton over arbitrary int symbols (transitions kept in a map per
// state). State 0 is the root and the other states are numbered 1..cnt.
struct Sam
{
	int las, cnt;
	map<int, int> tra[maxn << 1];
	int len[maxn << 1], link[maxn << 1], siz[maxn << 1];
	// Makes state 0 the only state, with suffix link -1.
	In void init() {link[las = cnt = 0] = -1;}
	// Appends symbol x to the sequence.
	In void insert(int x)
	{
		int now = ++cnt, p = las;
		len[now] = len[las] + 1; siz[now] = 1;
		while(~p && !tra[p].count(x)) tra[p][x] = now, p = link[p];
		if(p == -1) link[now] = 0;
		else
		{
			int q = tra[p][x];
			if(len[q] == len[p] + 1) link[now] = q;
			else
			{
				// q is longer than len[p] + 1: split off a clone of exactly that length
				int clo = ++cnt;
				tra[clo] = tra[q]; len[clo] = len[p] + 1;
				link[clo] = link[q];
				link[q] = link[now] = clo;
				while(~p && tra[p][x] == q) tra[p][x] = clo, p = link[p];
			}
		}
		las = now;
	}
	int pos[maxn << 1];
	// Accumulates siz along suffix links (states visited by decreasing len)
	// and returns the largest len among states whose accumulated siz is at
	// least K; 0 if there is none.
	In int dfs()
	{
		int ret = 0;
		for(int i = 1; i <= cnt; ++i) pos[i] = i;
		// std::sort variant of the ordering. The comparator takes its arguments
		// by value (it must not require mutable references) and captures this
		// explicitly to reach len.
		sort(pos + 1, pos + cnt + 1, [this](int a, int b) {return len[a] > len[b];});
		for(int i = 1; i <= cnt; ++i)
		{
			siz[link[pos[i]]] += siz[pos[i]];
			if(siz[pos[i]] >= K) ret = max(ret, len[pos[i]]);
		}
		return ret;
	}
}S;
// Reads n and K, inserts the n following values into the automaton and
// prints S.dfs().
int main()
{
	n = read();
	K = read();
	S.init();
	for(int remaining = n; remaining > 0; --remaining)
	{
		int x = read();
		S.insert(x);
	}
	write(S.dfs());
	enter;
	return 0;
}
[TJOI2015]弦论
求第\(k\)小的子串。
因为每一个子串代表一条路径,所以我们求出从每一个节点开始有多少条路径,然后像在平衡树上找第\(k\)小的方法找即可。
题目还分了两种情况:\(t\)为0的话,相同的子串只算一次,每一个节点的endpos数量就都当作1;\(t\)为1的话,每一个节点的endpos数量就是parent tree上的子树大小了。
至于怎么求从每一个点开始的路径数量,上面第二篇博客有讲。
#include<cstdio>
#include<iostream>
#include<cmath>
#include<algorithm>
#include<cstring>
#include<cstdlib>
#include<cctype>
#include<vector>
#include<stack>
#include<queue>
using namespace std;

// Output shorthands: print a newline / a single blank.
#define enter puts("")
#define space putchar(' ')
// Fill an entire array with byte value x (a must be a real array: sizeof is used).
#define Mem(a, x) memset(a, x, sizeof(a))
#define In inline

using ll = long long;
using db = double;

constexpr int INF = 0x3f3f3f3f;
constexpr db eps = 1e-8;
// Capacity of the input buffer; the automaton arrays are sized maxn << 1.
constexpr int maxn = 500005;
// Reads the next decimal integer from stdin.
// Skips everything up to the first digit; the value is negated when the
// character immediately before the first digit is '-'.
// Returns 0 when end of input is reached before any digit.
inline ll read()
{
	ll ans = 0;
	// int, not char: EOF must stay distinguishable and isdigit() needs an
	// unsigned-char value or EOF.
	int ch = getchar(), last = ' ';
	while(ch != EOF && !isdigit(ch)) last = ch, ch = getchar();
	while(ch != EOF && isdigit(ch)) ans = ans * 10 + (ch - '0'), ch = getchar();
	return last == '-' ? -ans : ans;
}
// Prints x in decimal to stdout, with a leading '-' for negative values.
inline void write(ll x)
{
	if(x < 0)
	{
		putchar('-');
		x = -x;
	}
	// Collect digits least-significant first, then emit them in reverse.
	char digits[20];
	int top = 0;
	do
	{
		digits[top++] = x % 10 + '0';
		x /= 10;
	} while(x);
	while(top) putchar(digits[--top]);
}
// Flg: when zero, Sam::dfs() counts every state once; otherwise it uses the
// accumulated siz. K: the rank passed to Sam::print().
int Flg, K;
// Input buffer: main() reads the string into it and feeds it to the automaton.
char s[maxn];
struct Sam
{
int las, cnt;
int tra[maxn << 1][30], len[maxn << 1], link[maxn << 1], siz[maxn << 1];
In void init() {link[las = cnt = 1] = -1;}
In void insert(int c)
{
int now = ++cnt, p = las;
len[now] = len[las] + 1; siz[now] = 1;
while(~p && !tra[p][c]) tra[p][c] = now, p = link[p];
if(p == -1) link[now] = 1;
else
{
int q = tra[p][c];
if(len[q] == len[p] + 1) link[now] = q;
else
{
int clo = ++cnt;
memcpy(tra[clo], tra[q], sizeof(tra[q]));
len[clo] = len[p] + 1;
link[clo] = link[q], link[q] = link[now] = clo;
while(~p && tra[p][c] == q) tra[p][c] = clo, p = link[p];
}
}
las = now;
}
int buc[maxn << 1], pos[maxn << 1], sum[maxn << 1];
In void dfs()
{
for(int i = 1; i <= cnt; ++i) ++buc[len[i]];
for(int i = 1; i <= cnt; ++i) buc[i] += buc[i - 1];
for(int i = 1; i <= cnt; ++i) pos[buc[len[i]]--] = i;
for(int i = cnt; i; --i) siz[link[pos[i]]] += siz[pos[i]];
for(int i = 1; i <= cnt; ++i)
{
if(!Flg) sum[i] = siz[i] = 1;
else sum[i] = siz[i];
}
siz[1] = 0;
for(int i = cnt; i; --i)
for(int j = 0; j < 26; ++j)
if(tra[pos[i]][j]) sum[pos[i]] += sum[tra[pos[i]][j]];
}
In void print(int k)
{
if(sum[1] < k) {write(-1); return;}
int now = 1;
k -= siz[now];
while(k)
{
int c = 0;
while(k > sum[tra[now][c]]) k -= sum[tra[now][c++]];
now = tra[now][c];
putchar('a' + c); k -= siz[now];
}
}
}S;
// Reads the string, builds its automaton, then reads Flg and K and prints
// the substring selected by S.print(K).
int main()
{
	scanf("%s", s);
	S.init();
	// Walk the buffer up to its terminating NUL.
	for(char* cur = s; *cur; ++cur) S.insert(*cur - 'a');
	Flg = read();
	K = read();
	S.dfs();
	S.print(K);
	enter;
	return 0;
}