回文树
PAM
用以处理回文串问题的一类自动机
每个节点代表一类回文串
节点信息:
回文串长度,fail指针,子节点,出现次数等等
初始化
初始化回文串的时,建立两个节点,长度分别为\(-1\)和\(0\),代表奇数回文串和偶数回文串,并标记偶数节点的\(fail\)为奇数节点【当任意长度的回文串都不存在时,那么节点自身就作为它结尾的最长回文串】
同时记一个\(last\)为上一个插入的节点
PAM(){len[siz = 1] = -1,s[0] = -1,fail[0] = 1;}
插入操作
不断沿\(last\)的\(fail\)指针往上找,直至找到\(s[n - 1 - len[u]] = s[n]\)的节点\(u\),便可插入其后
设置\(fail\)指针时,同样沿\(fail\)指针查找,直至找到\(s[n - 1 - len[u]] = s[n]\)的节点\(u\),作为\(fail\)
int find(int u){
while (s[n - 1 - len[u]] != s[n]) u = fail[u];
return u;
}
void ins(int x){
s[++n] = x; int cur = find(last);
if (!ch[cur][x]){
fail[++siz] = ch[find(fail[cur])][x];
ch[cur][x] = siz; len[siz] = len[cur] + 2;
}
last = ch[cur][x]; cnt[last]++;
}
统计
每个节点出现次数还与其儿子节点有关,所以统计时还要加上子树的贡献
LL count(){
LL re = 0;
for (int i = siz; ~i; i--){
cnt[fail[i]] += cnt[i];
re = max(re,1ll * cnt[i] * len[i]);
}
return re;
}
这样就可以\(A\)掉这道题辣
#include<algorithm>
#include<iostream>
#include<cstdlib>
#include<cstring>
#include<cstdio>
#include<vector>
#include<queue>
#include<cmath>
#include<map>
#define LL long long int
#define REP(i,n) for (int i = 1; i <= (n); i++)
#define Redge(u) for (int k = h[u],to; k; k = ed[k].nxt)
#define cls(s,v) memset(s,v,sizeof(s))
#define mp(a,b) make_pair<int,int>(a,b)
#define cp pair<int,int>
using namespace std;
const int maxn = 300005,maxm = 100005,INF = 0x3f3f3f3f;
inline int read(){
int out = 0,flag = 1; char c = getchar();
while (c < 48 || c > 57){if (c == '-') flag = 0; c = getchar();}
while (c >= 48 && c <= 57){out = (out << 1) + (out << 3) + c - 48; c = getchar();}
return flag ? out : -out;
}
char s[maxn];
int n;
struct PAM{
int n,last,siz,ch[maxn][26],fail[maxn],cnt[maxn],len[maxn],s[maxn];
PAM(){len[siz = 1] = -1,s[0] = -1,fail[0] = 1;}
int find(int u){
while (s[n - 1 - len[u]] != s[n]) u = fail[u];
return u;
}
void ins(int x){
s[++n] = x; int cur = find(last);
if (!ch[cur][x]){
fail[++siz] = ch[find(fail[cur])][x];
ch[cur][x] = siz; len[siz] = len[cur] + 2;
}
last = ch[cur][x]; cnt[last]++;
}
LL count(){
LL re = 0;
for (int i = siz; ~i; i--){
cnt[fail[i]] += cnt[i];
re = max(re,1ll * cnt[i] * len[i]);
}
return re;
}
}T;
int main(){
scanf("%s",s + 1); n = strlen(s + 1);
REP(i,n) T.ins(s[i] - 'a');
printf("%lld\n",T.count());
return 0;
}