后缀自动机入门题集

学了好几天后缀自动机,总算是真正搞懂了,才敢来发博客。


后缀自动机是啥以及怎么构造就不说了,毕竟有很多博客比我讲的好多了。
还是按照国际惯例,推荐几发:
hiho一下 127~132周
后缀自动机入门
史上最通俗的后缀自动机详解


先谈谈我对parent tree的理解:
首先由后缀链接构成的树就叫做parent tree(也叫后缀链接树)。
但是我们存的是反边,所以想dfs的时候就很不舒服,有两种解决办法:
1.倒着存正边(没见人用过)。
2.根据SAM的性质,子节点所代表的最长的字符串的长度一定大于父亲节点的,所以根据len的大小排序,然后从大往小更新即可。排序的时候用桶排序,可以省去一个log。
后缀链接的父亲节点的endpos完全包含子节点的endpos,且父亲节点所代表的字符串是子节点的后缀。
关于每一个节点存的endpos的数量,父亲节点的不一定恰好比子节点多1,而是等于所有子节点的endpos数量之和;如果父亲节点本身对应原串的一个前缀(即插入时新建、而不是复制出来的节点),还要再加1。


接下来是入门题:
luogu P3804 【模板】后缀自动机
确实是模板,我们建完后缀自动机后,求出每一个节点的endpos的数量,做法就是插入的同时标记一下这个节点是一个endpos。然后递归求一遍子树的endpos之和就是这个节点的endpos个数咯。之后如果siz[i] > 1,就用siz[i] * len[i]更新答案。

#include<cstdio>
#include<iostream>
#include<cmath>
#include<algorithm>
#include<cstring>
#include<cstdlib>
#include<cctype>
#include<vector>
#include<stack>
#include<queue>
// Contest-style preamble: names from namespace std are used unqualified below.
using namespace std;
// `enter` emits a newline via puts, `space` emits a single blank via putchar.
#define enter puts("") 
#define space putchar(' ')
// Mem(a, x): byte-wise memset over the whole object `a` (sizeof is taken of `a` itself).
#define Mem(a, x) memset(a, x, sizeof(a))
// `In` is shorthand for the `inline` keyword.
#define In inline
typedef long long ll;
typedef double db;
const int INF = 0x3f3f3f3f;
const db eps = 1e-8;
// Capacity of the input buffer; the automaton arrays are sized maxn << 1.
const int maxn = 1e6 + 5;
// Width of one transition row; inserted symbols are s[i] - 'a'.
const int maxs = 30;
/**
 * Reads one decimal integer from stdin.
 *
 * Skips everything up to the first digit; the value is negated when the
 * character immediately before that digit is '-'.
 * Returns 0 if the input ends before any digit is seen (the previous version
 * spun forever in that case).
 */
inline ll read()
{
  ll ans = 0;
  // Keep getchar()'s result as int: EOF stays distinguishable and isdigit()
  // never receives a negative char value.
  int ch = getchar(), last = ' ';
  while(ch != EOF && !isdigit(ch)) last = ch, ch = getchar();
  while(ch != EOF && isdigit(ch)) ans = ans * 10 + (ch - '0'), ch = getchar();
  if(last == '-') ans = -ans;
  return ans;
}
/**
 * Prints x in decimal to stdout, with a leading '-' for negative values.
 * No trailing separator is written.
 */
inline void write(ll x)
{
  // Work on the unsigned magnitude so the most negative value can be negated
  // without signed overflow.
  unsigned long long mag = x;
  if(x < 0) putchar('-'), mag = 0 - mag;
  char digits[24];
  int top = 0;
  // Digits come out least-significant first; replay them in reverse.
  do digits[top++] = static_cast<char>('0' + mag % 10), mag /= 10; while(mag);
  while(top) putchar(digits[--top]);
}

// Input text buffer; main() fills it with scanf("%s") and feeds it to the automaton.
char s[maxn];
/**
 * Suffix automaton with occurrence counting.
 *
 * State 0 is the initial state; states 1..cnt are created by insert().
 * None of the arrays is cleared here, so the object relies on zero-initialised
 * storage (the only instance in this file, S, is a global), and dfs() is
 * meant to be called once after all insertions.
 */
struct Sam
{
  int las, cnt;   // las: state of the whole text so far; cnt: largest state id in use
  int tra[maxn << 1][maxs], link[maxn << 1], len[maxn << 1];   // transitions (0 = absent), suffix link, longest length
  int siz[maxn << 1];   // 1 for states created directly by insert(); dfs() turns it into a subtree sum
  /** Resets the automaton to the single initial state (suffix link -1, length 0). */
  In void init()
  {
    las = cnt = 0;
    link[0] = -1;
    len[0] = 0;
  }
  /** Appends symbol c (an index into a transition row) to the text. */
  In void insert(int c)
  {
    const int now = ++cnt;
    len[now] = len[las] + 1;
    int p = las;
    // Every state on the suffix-link chain that lacks a c-transition gets one to the new state.
    for(; p != -1 && !tra[p][c]; p = link[p]) tra[p][c] = now;
    if(p == -1) link[now] = 0;
    else
    {
      const int q = tra[p][c];
      if(len[q] == len[p] + 1) link[now] = q;
      else
      {
        // q also stands for longer strings: split off a clone capped at len[p] + 1.
        const int clo = ++cnt;
        len[clo] = len[p] + 1;
        memcpy(tra[clo], tra[q], sizeof(tra[q]));
        link[clo] = link[q];
        link[q] = link[now] = clo;
        for(; p != -1 && tra[p][c] == q; p = link[p]) tra[p][c] = clo;
      }
    }
    las = now;
    siz[now] = 1;   // clones are never marked and keep 0
  }
  int buc[maxn << 1], pos[maxn << 1];   // counting-sort buckets / states ordered by len
  /**
   * Accumulates siz[] up the suffix-link tree and returns the maximum of
   * siz[v] * len[v] over states v with siz[v] > 1 (0 if there is none).
   *
   * The result is returned as ll: with up to maxn characters the product
   * can be about (maxn / 2)^2, which does not fit in a 32-bit int.
   */
  In ll dfs()
  {
    ll ret = 0;
    // Counting sort of states 1..cnt by len; pos[1..cnt] ends up in ascending order.
    for(int i = 1; i <= cnt; ++i) ++buc[len[i]];
    for(int i = 1; i <= cnt; ++i) buc[i] += buc[i - 1];
    for(int i = 1; i <= cnt; ++i) pos[buc[len[i]]--] = i;
    // Longest first: a state's children in the link tree are always longer,
    // so siz[now] is final by the time now is visited.
    for(int i = cnt; i; --i)
    {
      const int now = pos[i];
      siz[link[now]] += siz[now];
      if(siz[now] > 1) ret = max(ret, static_cast<ll>(siz[now]) * len[now]);
    }
    return ret;
  }
}S;

/** Reads the text, builds the automaton character by character and prints the answer of S.dfs(). */
int main()
{
  scanf("%s", s);
  S.init();
  // Walk the NUL-terminated buffer directly instead of indexing up to strlen.
  for(const char* cur = s; *cur; ++cur) S.insert(*cur - 'a');
  write(S.dfs());
  enter;
  return 0;
}

SP1811 LCS
求两个串的lcs。
把一个串建成后缀自动机,然后另一个串在上面跑,相当于枚举另一个串的前缀,看每一个前缀的后缀最多能和原串匹配多少。每成功匹配一个节点,就用当前匹配的长度更新答案。

#include<cstdio>
#include<iostream>
#include<cmath>
#include<algorithm>
#include<cstring>
#include<cstdlib>
#include<cctype>
#include<vector>
#include<queue>
#include<assert.h>
#include<ctime>
// Contest-style preamble: names from namespace std are used unqualified below.
using namespace std;
// `enter` emits a newline via puts, `space` emits a single blank via putchar.
#define enter puts("") 
#define space putchar(' ')
// Mem(a, x): byte-wise memset over the whole object `a` (sizeof is taken of `a` itself).
#define Mem(a, x) memset(a, x, sizeof(a))
// `In` is shorthand for the `inline` keyword.
#define In inline
// Adjacency-list iteration helper; it expects `head` and `e` to exist and is not used in the visible code.
#define forE(i, x, y) for(int i = head[x], y; ~i && (y = e[i].to); i = e[i].nxt)
typedef long long ll;
typedef double db;
const int INF = 0x3f3f3f3f;
const db eps = 1e-8;
// Capacity of the input buffer; the automaton arrays are sized maxn << 1.
const int maxn = 2.5e5 + 5;
// Width of one transition row; inserted symbols are s[i] - 'a'.
const int maxs = 30;
/**
 * Reads one decimal integer from stdin.
 *
 * Skips everything up to the first digit; the value is negated when the
 * character immediately before that digit is '-'.
 * Returns 0 if the input ends before any digit is seen (the previous version
 * spun forever in that case).
 */
In ll read()
{
	ll ans = 0;
	// Keep getchar()'s result as int: EOF stays distinguishable and isdigit()
	// never receives a negative char value.
	int ch = getchar(), las = ' ';
	while(ch != EOF && !isdigit(ch)) las = ch, ch = getchar();
	while(ch != EOF && isdigit(ch)) ans = ans * 10 + (ch - '0'), ch = getchar();
	if(las == '-') ans = -ans;
	return ans;
}
/**
 * Prints x in decimal to stdout, with a leading '-' for negative values.
 * No trailing separator is written.
 */
In void write(ll x)
{
	// Work on the unsigned magnitude so the most negative value can be negated
	// without signed overflow.
	unsigned long long mag = x;
	if(x < 0) putchar('-'), mag = 0 - mag;
	char digits[24];
	int top = 0;
	// Digits come out least-significant first; replay them in reverse.
	do digits[top++] = static_cast<char>('0' + mag % 10), mag /= 10; while(mag);
	while(top) putchar(digits[--top]);
}
/**
 * Redirects stdin to ".in" and stdout to ".out" unless the macro `mrclr`
 * is defined at compile time, in which case it does nothing.
 */
In void MYFILE()
{
#if !defined(mrclr)
	// Rebind the standard streams to the fixed file names.
	freopen(".in", "r", stdin);
	freopen(".out", "w", stdout);
#endif
}

// Length of the first input string; set in main().
int n;
// Shared buffer: holds the first string while the automaton is built, then the second string.
char s[maxn];
/**
 * Suffix automaton used to compute the longest common substring of the
 * inserted text and a second string.
 *
 * State 0 is the initial state; a transition value of 0 means "absent".
 * len[] and link[] are not cleared here, so the object relies on
 * zero-initialised storage (the only instance in this file, S, is a global).
 */
struct Sam
{
	int tra[maxn << 1][maxs], len[maxn << 1], link[maxn << 1], cnt, las;
	/** Resets the automaton to the single initial state (suffix link -1). */
	In void init() {link[cnt = las = 0] = -1;}
	/** Appends symbol c (an index into a transition row) to the text. */
	In void insert(int c)
	{
		const int now = ++cnt;
		int p = las;
		Mem(tra[now], 0);
		len[now] = len[las] + 1;
		// Every state on the suffix-link chain that lacks a c-transition gets one to the new state.
		for(; ~p && !tra[p][c]; p = link[p]) tra[p][c] = now;
		if(p == -1) link[now] = 0;
		else
		{
			const int q = tra[p][c];
			if(len[q] == len[p] + 1) link[now] = q;
			else
			{
				// q also stands for longer strings: split off a clone capped at len[p] + 1.
				const int clo = ++cnt;
				len[clo] = len[p] + 1;
				memcpy(tra[clo], tra[q], sizeof(tra[q]));
				link[clo] = link[q], link[q] = link[now] = clo;
				for(; ~p && tra[p][c] == q; p = link[p]) tra[p][c] = clo;
			}
		}
		las = now;
	}
	/**
	 * Runs s through the automaton and returns the length of the longest
	 * substring of s that is also a substring of the inserted text.
	 * Characters are mapped with s[i] - 'a'.
	 */
	In int solve(char* s)
	{
		int ret = 0, n = strlen(s);
		for(int i = 0, p = 0, l = 0; i < n; ++i)
		{
			const int c = s[i] - 'a';
			// Drop to shorter suffixes until one can be extended by c.
			// len[] is only consulted while p is a real state; the previous
			// version read len[-1] after falling off the initial state.
			while(~p && !tra[p][c])
			{
				p = link[p];
				if(~p) l = len[p];
			}
			if(p == -1) p = l = 0;   // no suffix extends: restart from the initial state
			else p = tra[p][c], ++l;
			ret = max(ret, l);
		}
		return ret;
	}
}S;

/** Builds the automaton from the first string, then prints the result of S.solve() for the second. */
int main()
{
//	MYFILE();
	scanf("%s", s);
	n = strlen(s);
	S.init();
	// Feed the first string one character at a time.
	for(const char* cur = s; cur != s + n; ++cur) S.insert(*cur - 'a');
	// The buffer is reused for the second string.
	scanf("%s", s);
	write(S.solve(s));
	enter;
	return 0;
}

SP1812 LCS2
求多个串的lcs。
还是先把一个串建成后缀自动机。
然后对于每一个串,都放在后缀自动机上跑,记录在每一个节点能匹配的最大长度。然后这些长度取min,就是所有串在每一个节点能匹配的最大长度。最后答案遍历每一个节点取max即可。
不过还得想一想的是,如果这个节点成功匹配了,那么他的所有祖先节点显然也是匹配了的,但是却只标记了这个节点。所以在每一个串跑完后缀自动机后,要从叶子节点开始把标记再往上传一遍(祖先节点能匹配的长度不超过它自己的len),更新所有祖先节点的匹配情况。

#include<cstdio>
#include<iostream>
#include<cmath>
#include<algorithm>
#include<cstring>
#include<cstdlib>
#include<cctype>
#include<vector>
#include<stack>
#include<queue>
// Contest-style preamble: names from namespace std are used unqualified below.
using namespace std;
// `enter` emits a newline via puts, `space` emits a single blank via putchar.
#define enter puts("") 
#define space putchar(' ')
// Mem(a, x): byte-wise memset over the whole object `a` (sizeof is taken of `a` itself).
#define Mem(a, x) memset(a, x, sizeof(a))
// `In` is shorthand for the `inline` keyword.
#define In inline
typedef long long ll;
typedef double db;
// Also the value every int holds after Mem(arr, 0x3f).
const int INF = 0x3f3f3f3f;
const db eps = 1e-8;
// Capacity of the input buffer; the automaton arrays are sized maxn << 1.
const int maxn = 1e5 + 5;
/**
 * Reads one decimal integer from stdin.
 *
 * Skips everything up to the first digit; the value is negated when the
 * character immediately before that digit is '-'.
 * Returns 0 if the input ends before any digit is seen (the previous version
 * spun forever in that case).
 */
inline ll read()
{
  ll ans = 0;
  // Keep getchar()'s result as int: EOF stays distinguishable and isdigit()
  // never receives a negative char value.
  int ch = getchar(), last = ' ';
  while(ch != EOF && !isdigit(ch)) last = ch, ch = getchar();
  while(ch != EOF && isdigit(ch)) ans = ans * 10 + (ch - '0'), ch = getchar();
  if(last == '-') ans = -ans;
  return ans;
}
/**
 * Prints x in decimal to stdout, with a leading '-' for negative values.
 * No trailing separator is written.
 */
inline void write(ll x)
{
  // Work on the unsigned magnitude so the most negative value can be negated
  // without signed overflow.
  unsigned long long mag = x;
  if(x < 0) putchar('-'), mag = 0 - mag;
  char digits[24];
  int top = 0;
  // Digits come out least-significant first; replay them in reverse.
  do digits[top++] = static_cast<char>('0' + mag % 10), mag /= 10; while(mag);
  while(top) putchar(digits[--top]);
}

// Input buffer; main() reuses it for every string it reads.
char s[maxn];
/**
 * Suffix automaton of the first string, used to intersect it with any number
 * of further strings.
 *
 * Intended call order (as done by main): init(), insert() per character,
 * sort() once, then lcs() once per additional string. Min[] and Max[] are
 * initialised by the caller. The arrays are not cleared here, so the object
 * relies on zero-initialised storage (the only instance, S, is a global).
 */
struct Sam
{
  int las, cnt;   // las: state of the whole text so far; cnt: largest state id in use
  int tra[maxn << 1][30], len[maxn << 1], link[maxn << 1];   // transitions (0 = absent), longest length, suffix link
  /** Resets the automaton to the single initial state 0 (suffix link -1). */
  In void init() {link[las = cnt = 0] = -1;}
  /** Appends symbol c (an index into a transition row) to the text. */
  In void insert(int c)
  {
    const int now = ++cnt;
    int p = las;
    len[now] = len[las] + 1;
    // Every state on the suffix-link chain that lacks a c-transition gets one to the new state.
    for(; ~p && !tra[p][c]; p = link[p]) tra[p][c] = now;
    if(p == -1) link[now] = 0;
    else
    {
      const int q = tra[p][c];
      if(len[q] == len[p] + 1) link[now] = q;
      else
      {
        // q also stands for longer strings: split off a clone capped at len[p] + 1.
        const int clo = ++cnt;
        memcpy(tra[clo], tra[q], sizeof(tra[q]));
        len[clo] = len[p] + 1;
        link[clo] = link[q]; link[q] = link[now] = clo;
        for(; ~p && tra[p][c] == q; p = link[p]) tra[p][c] = clo;
      }
    }
    las = now;
  }
  int buc[maxn << 1], pos[maxn << 1];   // counting-sort buckets / states ordered by len
  /** Counting sort of states 1..cnt by len; fills pos[1..cnt] in ascending order of len. */
  In void sort()
  {
    for(int i = 1; i <= cnt; ++i) ++buc[len[i]];
    for(int i = 1; i <= cnt; ++i) buc[i] += buc[i - 1];
    for(int i = 1; i <= cnt; ++i) pos[buc[len[i]]--] = i;
  }
  int Max[maxn << 1], Min[maxn << 1];   // best match of the current string / minimum over all strings so far
  /**
   * Matches s against the automaton and folds the per-state match lengths
   * into Min[]. Requires sort() to have been called; leaves Max[] zeroed for
   * states 1..cnt so the next call starts clean.
   */
  In void lcs(char* s)
  {
    int n = strlen(s);
    for(int i = 0, p = 0, l = 0; i < n; ++i)
    {
      const int c = s[i] - 'a';
      // Drop to shorter suffixes until one can be extended by c.
      // len[] is only consulted while p is a real state; the previous
      // version read len[-1] after falling off the initial state.
      while(~p && !tra[p][c])
      {
        p = link[p];
        if(~p) l = len[p];
      }
      if(p == -1) p = l = 0;   // no suffix extends: restart from the initial state
      else ++l, p = tra[p][c], Max[p] = max(Max[p], l);
    }
    // Longest states first: a match in a state is also a match in its
    // suffix-link parent, capped at the parent's own longest length.
    for(int i = cnt; i; --i)
    {
      const int now = pos[i], fa = link[now];
      Max[fa] = max(Max[fa], min(Max[now], len[fa]));
      Min[now] = min(Min[now], Max[now]); Max[now] = 0;
    }
  }
}S;

/**
 * Builds the automaton from the first string, intersects it with every
 * further string on stdin and prints the length of the longest substring
 * common to all of them.
 */
int main()
{
  //freopen("ha.in", "r", stdin);
  scanf("%s", s);
  int n = strlen(s); S.init();
  for(int i = 0; i < n; ++i) S.insert(s[i] - 'a');
  S.sort();
  Mem(S.Min, 0x3f); Mem(S.Max, 0);
  int others = 0;   // number of strings matched against the automaton
  while(scanf("%s", s) == 1) S.lcs(s), ++others;
  int ans = 0;
  if(!others)
  {
    // Nothing to intersect with: Min[] still holds the 0x3f3f3f3f sentinel,
    // so report the whole first string instead of printing the sentinel.
    ans = n;
  }
  else
  {
    for(int i = 1; i <= S.cnt; ++i) ans = max(ans, S.Min[i]);
  }
  write(ans), enter;
  return 0;
}

[USACO06DEC]Milk Patterns
找出现了至少\(k\)次的最长的子串。
建完后缀自动机,求一遍子树大小,然后如果大于等于\(k\)就用该节点的len更新答案。

#include<cstdio>
#include<iostream>
#include<cmath>
#include<algorithm>
#include<cstring>
#include<cstdlib>
#include<cctype>
#include<vector>
#include<stack>
#include<queue>
#include<map>
// Contest-style preamble: names from namespace std are used unqualified below.
using namespace std;
// `enter` emits a newline via puts, `space` emits a single blank via putchar.
#define enter puts("") 
#define space putchar(' ')
// Mem(a, x): byte-wise memset over the whole object `a` (sizeof is taken of `a` itself).
#define Mem(a, x) memset(a, x, sizeof(a))
// `In` is shorthand for the `inline` keyword.
#define In inline
typedef long long ll;
typedef double db;
const int INF = 0x3f3f3f3f;
const db eps = 1e-8;
// Sizing constant; the automaton arrays are sized maxn << 1.
const int maxn = 2e4 + 5;
/**
 * Reads one decimal integer from stdin.
 *
 * Skips everything up to the first digit; the value is negated when the
 * character immediately before that digit is '-'.
 * Returns 0 if the input ends before any digit is seen (the previous version
 * spun forever in that case).
 */
inline ll read()
{
  ll ans = 0;
  // Keep getchar()'s result as int: EOF stays distinguishable and isdigit()
  // never receives a negative char value.
  int ch = getchar(), last = ' ';
  while(ch != EOF && !isdigit(ch)) last = ch, ch = getchar();
  while(ch != EOF && isdigit(ch)) ans = ans * 10 + (ch - '0'), ch = getchar();
  if(last == '-') ans = -ans;
  return ans;
}
/**
 * Prints x in decimal to stdout, with a leading '-' for negative values.
 * No trailing separator is written.
 */
inline void write(ll x)
{
  // Work on the unsigned magnitude so the most negative value can be negated
  // without signed overflow.
  unsigned long long mag = x;
  if(x < 0) putchar('-'), mag = 0 - mag;
  char digits[24];
  int top = 0;
  // Digits come out least-significant first; replay them in reverse.
  do digits[top++] = static_cast<char>('0' + mag % 10), mag /= 10; while(mag);
  while(top) putchar(digits[--top]);
}

// n: how many values main() reads and inserts; K: occurrence threshold compared against siz[] in Sam::dfs().
int n, K;

/**
 * Suffix automaton over an integer alphabet (transitions stored in std::map),
 * used to find the longest substring occurring at least K times.
 *
 * State 0 is the initial state. len[], link[] and siz[] are not cleared
 * here, so the object relies on zero-initialised storage (the only
 * instance, S, is a global).
 */
struct Sam
{
  int las, cnt;   // las: state of the whole sequence so far; cnt: largest state id in use
  map<int, int> tra[maxn << 1];   // per-state transitions: symbol -> target state
  int len[maxn << 1], link[maxn << 1], siz[maxn << 1];   // longest length, suffix link, occurrence count
  /** Resets the automaton to the single initial state (suffix link -1). */
  In void init() {link[las = cnt = 0] = -1;}
  /** Appends symbol x to the sequence. */
  In void insert(int x)
  {
    const int now = ++cnt;
    int p = las;
    len[now] = len[las] + 1; siz[now] = 1;   // clones created below keep siz 0
    // Every state on the suffix-link chain that lacks an x-transition gets one to the new state.
    for(; ~p && !tra[p].count(x); p = link[p]) tra[p][x] = now;
    if(p == -1) link[now] = 0;
    else
    {
      const int q = tra[p][x];
      if(len[q] == len[p] + 1) link[now] = q;
      else
      {
        // q also stands for longer strings: split off a clone capped at len[p] + 1.
        const int clo = ++cnt;
        tra[clo] = tra[q]; len[clo] = len[p] + 1;
        link[clo] = link[q];
        link[q] = link[now] = clo;
        for(; ~p && tra[p][x] == q; p = link[p]) tra[p][x] = clo;
      }
    }
    las = now;
  }
  int pos[maxn << 1];   // states 1..cnt ordered by decreasing len
  /**
   * Accumulates siz[] up the suffix-link tree and returns the largest len[v]
   * among states with siz[v] >= K (0 if there is none).
   */
  In int dfs()
  {
    int ret = 0;
    for(int i = 1; i <= cnt; ++i) pos[i] = i;
    // The comparator takes its arguments by value and captures `this`
    // explicitly; non-const reference parameters do not satisfy the
    // requirements std::sort places on a comparator.
    sort(pos + 1, pos + cnt + 1, [this](int a, int b) {return len[a] > len[b];});
    // Longest first, so siz[v] is final before it is pushed to its parent.
    for(int i = 1; i <= cnt; ++i)
    {
      const int v = pos[i];
      siz[link[v]] += siz[v];
      if(siz[v] >= K) ret = max(ret, len[v]);
    }
    return ret;
  }
}S;

/** Reads n and K, inserts the n following values into the automaton and prints S.dfs(). */
int main()
{
  n = read();
  K = read();
  S.init();
  // Each value goes straight from the reader into the automaton.
  for(int done = 0; done < n; ++done)
  {
    int x = read();
    S.insert(x);
  }
  write(S.dfs());
  enter;
  return 0;
}

[TJOI2015]弦论
求第\(k\)小的子串。
因为每一个子串代表一条路径,所以我们求出从每一个节点开始有多少条路径,然后像平衡树找第\(k\)大的方法找即可。
题目还分了两种情况:\(t\)为0的话,相同的子串只算一次,每一个节点的endpos数量就都当作1;\(t\)为1的话,每一个节点的endpos数量就是子树大小了。
至于怎么求从每一个点开始的路径数量,上面第二篇博客有讲。

#include<cstdio>
#include<iostream>
#include<cmath>
#include<algorithm>
#include<cstring>
#include<cstdlib>
#include<cctype>
#include<vector>
#include<stack>
#include<queue>
// Contest-style preamble: names from namespace std are used unqualified below.
using namespace std;
// `enter` emits a newline via puts, `space` emits a single blank via putchar.
#define enter puts("") 
#define space putchar(' ')
// Mem(a, x): byte-wise memset over the whole object `a` (sizeof is taken of `a` itself).
#define Mem(a, x) memset(a, x, sizeof(a))
// `In` is shorthand for the `inline` keyword.
#define In inline
typedef long long ll;
typedef double db;
const int INF = 0x3f3f3f3f;
const db eps = 1e-8;
// Capacity of the input buffer; the automaton arrays are sized maxn << 1.
const int maxn = 5e5 + 5;
/**
 * Reads one decimal integer from stdin.
 *
 * Skips everything up to the first digit; the value is negated when the
 * character immediately before that digit is '-'.
 * Returns 0 if the input ends before any digit is seen (the previous version
 * spun forever in that case).
 */
inline ll read()
{
  ll ans = 0;
  // Keep getchar()'s result as int: EOF stays distinguishable and isdigit()
  // never receives a negative char value.
  int ch = getchar(), last = ' ';
  while(ch != EOF && !isdigit(ch)) last = ch, ch = getchar();
  while(ch != EOF && isdigit(ch)) ans = ans * 10 + (ch - '0'), ch = getchar();
  if(last == '-') ans = -ans;
  return ans;
}
/**
 * Prints x in decimal to stdout, with a leading '-' for negative values.
 * No trailing separator is written.
 */
inline void write(ll x)
{
  // Work on the unsigned magnitude so the most negative value can be negated
  // without signed overflow.
  unsigned long long mag = x;
  if(x < 0) putchar('-'), mag = 0 - mag;
  char digits[24];
  int top = 0;
  // Digits come out least-significant first; replay them in reverse.
  do digits[top++] = static_cast<char>('0' + mag % 10), mag /= 10; while(mag);
  while(top) putchar(digits[--top]);
}

// Flg: 0 makes Sam::dfs() weight every state as 1, non-zero makes it use the accumulated siz[]; K: rank passed to Sam::print().
int Flg, K;
// Input text buffer; main() fills it with scanf("%s").
char s[maxn];
struct Sam
{
  int las, cnt;
  int tra[maxn << 1][30], len[maxn << 1], link[maxn << 1], siz[maxn << 1];
  In void init() {link[las = cnt = 1] = -1;}
  In void insert(int c)
  {
    int now = ++cnt, p = las;
    len[now] = len[las] + 1; siz[now] = 1;
    while(~p && !tra[p][c]) tra[p][c] = now, p = link[p];
    if(p == -1) link[now] = 1;
    else
      {
	int q = tra[p][c];
	if(len[q] == len[p] + 1) link[now] = q;
	else
	  {
	    int clo = ++cnt;
	    memcpy(tra[clo], tra[q], sizeof(tra[q]));
	    len[clo] = len[p] + 1;
	    link[clo] = link[q], link[q] = link[now] = clo;
	    while(~p && tra[p][c] == q) tra[p][c] = clo, p = link[p];
	  }
      }
    las = now;
  }
  int buc[maxn << 1], pos[maxn << 1], sum[maxn << 1];
  In void dfs()
  {
    for(int i = 1; i <= cnt; ++i) ++buc[len[i]];
    for(int i = 1; i <= cnt; ++i) buc[i] += buc[i - 1];
    for(int i = 1; i <= cnt; ++i) pos[buc[len[i]]--] = i;
    for(int i = cnt; i; --i) siz[link[pos[i]]] += siz[pos[i]];
    for(int i = 1; i <= cnt; ++i)
      {
	if(!Flg) sum[i] = siz[i] = 1;
	else sum[i] = siz[i];
      }
    siz[1] = 0;
    for(int i = cnt; i; --i)
      for(int j = 0; j < 26; ++j)
	if(tra[pos[i]][j]) sum[pos[i]] += sum[tra[pos[i]][j]];
  }
  In void print(int k)
  {
    if(sum[1] < k) {write(-1); return;}
    int now = 1;
    k -= siz[now];
    while(k)
      {
	int c = 0;
	while(k > sum[tra[now][c]]) k -= sum[tra[now][c++]];
	now = tra[now][c];
	putchar('a' + c); k -= siz[now];
      }
  }
}S;

/** Reads the text, then Flg and K, builds the automaton and prints the substring selected by S.print(K). */
int main()
{
  scanf("%s", s);
  S.init();
  // Walk the NUL-terminated buffer directly instead of indexing up to strlen.
  for(const char* cur = s; *cur; ++cur) S.insert(*cur - 'a');
  Flg = read();
  K = read();
  S.dfs();
  S.print(K);
  enter;
  return 0;
}
posted @ 2019-02-28 16:08  mrclr  阅读(475)  评论(0编辑  收藏  举报