回文自动机学习笔记
前言
考试遇到了新科技,所以不得不学一下。
根据各种自动机的姿势,PAM 是处理回文问题的有力工具。
OI-Wiki 讲的是好的,推荐学习。
PAM 的实用性在于,它用至多 \(O(n)\) 个节点存储了所有的回文串,Manacher 与之相比则有很大局限性。
构树
Manacher 利用插入 #
的方法避免的回文串长度奇偶性的讨论,因为这个讨论真的很麻烦。
而 PAM 的方法为,建立两个根,分别为奇根和偶根。
每个节点承载的信息,基本的 PAM 需要记录 \(len(i),fail(i)\) 分别表示节点 \(i\) 代表的回文后缀的长度和失配指针。
假设奇根的下标为 \(1\),偶根的下标为 \(0\),那么 \(len(0)=0,fail(0)=1,len(1)=-1\)。
而奇根不需要 \(fail\) 指针,因为每个字符一定能和自己形成长度为 \(1\) 的回文串。
每个 \(fail\) 指针实际指向的是自己的最长回文后缀,这个和大部分自动机是一样的。
考虑每次加入一个字符,都利用上一次的插入位置,跳 \(fail\) 来匹配回文串。
考虑在两头加字符,这时候奇根的 \(-1\) 就有很好的优势了,代码:
int Node(int l){
tot ++;
memset(ch[tot], 0, sizeof(ch[tot]));
len[tot] = l;
fail[tot] = 0;
return tot;
}
void Build(){
tot = - 1;
n = las = 0;
Node(0), Node(-1);
fail[0] = 1;
}
int Get_Fail(int x){
while(str[n - len[x] - 1] != str[n]) x = fail[x];
return x;
}
void Insert(char c){
str[++ n] = c;
int now = Get_Fail(las);
if(!ch[now][c - 'a']){
int x = Node(len[now] + 2);
fail[x] = ch[Get_Fail(fail[now])][c - 'a'];
ch[now][c - 'a'] = x;
}
las = ch[now][c - 'a'];
}
性质&应用
可以证明一个字符串本质不同的回文串个数至多为 \(O(|S|)\) 个,所以回文树节点个数是 \(O(n)\) 的。
如果需要统计一个字符串的本质不同回文子串个数,那么就是自动机的状态数。
而且可以在自动机上 DP 或利用它 DAG 的本质维护些奇奇怪怪的东西,反正都是线性的。
啥你需要模板题。
直接从最长后缀开始沿着 \(fail\) 指针遍历即可。
#include<cstdio>
#include<cstring>
#include<algorithm>
using namespace std;
typedef long long LL;
const int N = 3e5 + 10;
int n, tot, las, ch[N][30], len[N], cnt[N], fail[N];
char str[N];
int read(){
int x = 0, f = 1; char c = getchar();
while(c < '0' || c > '9') f = (c == '-') ? -1 : 1, c = getchar();
while(c >= '0' && c <= '9') x = x * 10 + c - 48, c = getchar();
return x * f;
}
int Node(int l){
tot ++;
memset(ch[tot], 0, sizeof(ch[tot]));
len[tot] = l;
cnt[tot] = fail[tot] = 0;
return tot;
}
void Build(){
tot = - 1;
n = las = 0;
Node(0), Node(-1);
fail[0] = 1;
}
int Get_Fail(int x){
while(str[n - len[x] - 1] != str[n]) x = fail[x];
return x;
}
void Insert(char c){
str[++ n] = c;
int now = Get_Fail(las);
if(!ch[now][c - 'a']){
int x = Node(len[now] + 2);
fail[x] = ch[Get_Fail(fail[now])][c - 'a'];
ch[now][c - 'a'] = x;
}
las = ch[now][c - 'a'];
cnt[las] ++;
}
int main(){
scanf("%s", str + 1);
Build(); int t = strlen(str + 1);
for(int i = 1; i <= t; i ++) Insert(str[i]);
LL ans = 0;
for(int i = tot; i >= 2; i --){
if(fail[i] > 1)
cnt[fail[i]] += cnt[i];
ans = max(ans, 1LL * cnt[i] * len[i]);
}
printf("%lld\n", ans);
return 0;
}