P6139 【模板】广义后缀自动机(广义 SAM)
题目链接
P6139 【模板】广义后缀自动机(广义 SAM)
【模板】广义后缀自动机(广义 SAM)
题目描述
给定 \(n\) 个由小写字母组成的字符串 \(s_1,s_2\ldots s_n\),求本质不同的子串个数。(不包含空串)
输入格式
第一行一个正整数 \(n\)。
以下 \(n\) 行,每行一个字符串,第 \(i\) 行表示字符串 \(s_{i-1}\)。
输出格式
一行一个正整数,表示答案。
样例输入
4
aa
ab
bac
caa
样例输出
10
提示
数据范围:\(1\le n\le 4\cdot 10^5\),\(1\le \sum{|s_i|}\le 10^6\)。
样例解释:共有 \(10\) 个本质不同的子串,它们是:"a","b","c","aa","ab","ac","ba","ca","bac","caa"
。
解题思路
后缀自动机
求 \(n\) 个字符串的子串数量,SAM 可求解一个字符串的本质不同的子串的数量,可将 \(n\) 个字符串合为一个字符串,中间用分隔字符标识即可,由于后缀自动机表示的有向无环图中的任意一条路径都与一个子串一一对应,所有的路径条数即本质不同的子串数量,但里面还有分割字符,每次扩展时忽略该分割字符即可
注意:最后还要减一,由于根节点是虚根节点,并没有贡献
- 时间复杂度:\(O(n)\)
广义后缀自动机
代码
- 后缀自动机
// Problem: P6139 【模板】广义后缀自动机(广义 SAM)
// Contest: Luogu
// URL: https://www.luogu.com.cn/problem/P6139
// Memory Limit: 500 MB
// Time Limit: 2000 ms
//
// Powered by CP Editor (https://cpeditor.org)
// %%%Skyqwq
#include <bits/stdc++.h>
// #define int long long
#define help {cin.tie(NULL); cout.tie(NULL);}
#define pb push_back
#define fi first
#define se second
#define mkp make_pair
using namespace std;
typedef long long LL;
typedef pair<int, int> PII;
typedef pair<LL, LL> PLL;
template <typename T> bool chkMax(T &x, T y) { return (y > x) ? x = y, 1 : 0; }
template <typename T> bool chkMin(T &x, T y) { return (y < x) ? x = y, 1 : 0; }
template <typename T> void inline read(T &x) {
int f = 1; x = 0; char s = getchar();
while (s < '0' || s > '9') { if (s == '-') f = -1; s = getchar(); }
while (s <= '9' && s >= '0') x = x * 10 + (s ^ 48), s = getchar();
x *= f;
}
const int N=3e6+5;
char s[N];
int n,cnt=1,lst=1;
LL f[N];
struct Node
{
int len,fa;
int ch[27];
}node[N];
void extend(int c)
{
int p=lst,np=lst=++cnt;
node[np].len=node[p].len+1;
for(;p&&!node[p].ch[c];p=node[p].fa)node[p].ch[c]=np;
if(!p)node[np].fa=1;
else
{
int q=node[p].ch[c];
if(node[q].len==node[p].len+1)node[np].fa=q;
else
{
int nq=++cnt;
node[nq]=node[q];
node[nq].len=node[p].len+1;
node[q].fa=node[np].fa=nq;
for(;p&&node[p].ch[c]==q;p=node[p].fa)node[p].ch[c]=nq;
}
}
}
LL dfs(int x)
{
if(f[x]!=-1)return f[x];
f[x]=1;
for(int i=0;i<26;i++)
if(node[x].ch[i])
f[x]+=dfs(node[x].ch[i]);
return f[x];
}
int main()
{
scanf("%d",&n);
while(n--)
{
scanf("%s",s);
int m=strlen(s);
for(int i=0;i<m;i++)extend(s[i]-'a');
extend(26);
}
memset(f,-1,sizeof f);
cout<<dfs(1)-1;
return 0;
}