题解 前缀
对每个前缀做一次全文匹配吗?
容易想到 AC 自动机,并且可以发现就是对 fail 树求 \(\sum siz_i\)
而且发现 AC 自动机上只有一个模式串
直接建空间开不下,需要每次跳 fail
那复杂度是什么呢?
- 只有一个模式串的 AC 自动机就是 KMP:因为我四十分钟都没意识到这个事所以还是写一下
我一点也不生气.jpg
于是易证复杂度是 \(O(n)\)
所以只有常数个模式串的 AC 自动机在建的时候跳 fail 来建的复杂度都是 \(O(n)\)
AC 自动机复杂度怎么证来着?
点击查看代码
#include <bits/stdc++.h>
using namespace std;
#define INF 0x3f3f3f3f
#define N 1000010
#define M 10000010
#define fir first
#define sec second
#define ll long long
//#define int long long
int n;
char s[M];
namespace force{
ll ans;
int tr[N][26], fail[N], head[N], siz[N], tot, ecnt;
struct edge{int to, next;}e[N];
inline void add(int s, int t) {e[++ecnt]={t, head[s]}; head[s]=ecnt;}
void build() {
queue<int> q;
int u=0;
for (int i=0; i<26; ++i)
if (tr[u][i]) q.push(tr[u][i]), fail[tr[u][i]]=u, add(u, tr[u][i]);
else tr[u][i]=tr[fail[u]][i];
while (q.size()) {
u=q.front(); q.pop();
for (int i=0; i<26; ++i)
if (tr[u][i]) q.push(tr[u][i]), fail[tr[u][i]]=tr[fail[u]][i], add(tr[fail[u]][i], tr[u][i]);
else tr[u][i]=tr[fail[u]][i];
}
}
void dfs(int u) {
siz[u]=1;
for (int i=head[u],v; ~i; i=e[i].next) {
v = e[i].to;
dfs(v);
siz[u]+=siz[v];
}
if (u) ans+=siz[u];
}
void solve() {
// cout<<double(sizeof(tr)+sizeof(fail)*5)/1000/1000<<endl;
memset(head, -1, sizeof(head));
int p=0, *t;
for (int i=1; i<=n; ++i,p=*t) {
t=&tr[p][s[i]-'a'];
if (!*t) *t=++tot;
}
build();
dfs(0);
printf("%lld\n", ans);
}
}
namespace task1{
ll ans;
pair<int, int> tr[M];
int fail[M], head[M], siz[M], tot, ecnt;
struct edge{int to, next;}e[M];
inline void add(int s, int t) {e[++ecnt]={t, head[s]}; head[s]=ecnt;}
void build() {
queue<int> q;
fail[tr[0].sec]=0, q.push(tr[0].sec), add(0, tr[0].sec);
while (q.size()) {
int u=q.front(); q.pop();
if (!tr[u].sec) continue;
int now=fail[u];
while (now && tr[now].fir!=tr[u].fir) now=fail[now];
if (tr[now].fir==tr[u].fir) fail[tr[u].sec]=tr[now].sec, add(tr[now].sec, tr[u].sec);
else fail[tr[u].sec]=0, add(0, tr[u].sec);
q.push(tr[u].sec);
}
}
void dfs(int u) {
// cout<<"dfs: "<<u<<endl;
siz[u]=1;
for (int i=head[u],v; ~i; i=e[i].next) {
v = e[i].to;
dfs(v);
siz[u]+=siz[v];
}
if (u) ans+=siz[u];
}
void solve() {
// cout<<double(sizeof(tr)+sizeof(fail)*5)/1000/1000<<endl;
memset(head, -1, sizeof(head));
int p=0, *t;
for (int i=1; i<=n; ++i,p=tr[p].sec)
tr[p]={s[i]-'a', ++tot};
build();
dfs(0);
printf("%lld\n", ans);
}
}
signed main()
{
freopen("pre.in", "r", stdin);
freopen("pre.out", "w", stdout);
scanf("%s", s+1);
n=strlen(s+1);
if (n<N) force::solve();
else task1::solve();
return 0;
}