杭电多校HDU 6599 I Love Palindrome String (回文树)题解
题意:
定义一个串为\(super\)回文串为:
\(\bullet\) 串s为主串str的一个子串,即\(s = str_lstr_{l + 1} \cdots str_r\)
\(\bullet\) 串s为回文串
\(\bullet\) 串\(str_lstr_{l + 1}...str_{\llcorner (l + r) / 2 \lrcorner}\)也是回文串
问长度为1、2、3 \(\cdots n\)的\(super\)回文串分别出现了几次
思路:
回文树建一下,然后每次新建一个节点的时候用hash快速判断一下是不是\(super\)回文串,然后回文树统计一下个数。
代码:
#include<map>
#include<set>
#include<cmath>
#include<cstdio>
#include<stack>
#include<ctime>
#include<vector>
#include<queue>
#include<cstring>
#include<string>
#include<sstream>
#include<iostream>
#include<algorithm>
using namespace std;
typedef long long ll;
typedef unsigned long long ull;
const int maxn = 3e5 + 5;
const int INF = 0x3f3f3f3f;
const ull seed = 131;
const ll MOD = 1e9 + 7;
using namespace std;
ull ha[maxn], fac[maxn];
int ans[maxn];
ull getstring(int l, int r){
return ha[r] - ha[l - 1] * fac[r - l + 1];
}
struct PAM{
int nex[maxn][26]; //指向的一个字符的节点
int fail[maxn]; //失配节点
int len[maxn]; //当前节点回文长度
int str[maxn]; //当前添加的字符串
int cnt[maxn]; //节点出现次数
int last;
int tot; //PAM中节点数
int N; //添加的串的个数
int satisfy[maxn];
int newnode(int L){
for(int i = 0; i < 26; i++) nex[tot][i] = 0;
len[tot] = L;
cnt[tot] = 0;
return tot++;
}
void init(){
tot = 0;
newnode(0);
newnode(-1);
last = 0;
N = 0;
str[0] = -1;
fail[0] = 1;
}
int getfail(int x){
while(str[N - len[x] - 1] != str[N]) x = fail[x];
return x;
}
void add(char ss){
int c = ss - 'a';
str[++N] = c;
int cur = getfail(last);
if(!nex[cur][c]){
int now = newnode(len[cur] + 2);
fail[now] = nex[getfail(fail[cur])][c];
nex[cur][c] = now;
int need = (len[now] + 1) / 2;
if(len[now] == 1 || getstring(N - len[now] + 1, N - len[now] + need) == getstring(N - need + 1, N)) satisfy[now] = 1;
else satisfy[now] = 0;
}
last = nex[cur][c];
cnt[last]++;
}
void count(){
for(int i = tot - 1; i >= 0; i--){
cnt[fail[i]] += cnt[i];
if(satisfy[i]) ans[len[i]] += cnt[i];
}
}
}pa;
char s[maxn];
int main(){
fac[0] = 1;
for(int i = 1; i < maxn; i++) fac[i] = fac[i - 1] * seed;
while(~scanf("%s", s + 1)){
pa.init();
int len = strlen(s + 1);
ha[0] = 1;
for(int i = 1; i <= len; i++){
ha[i] = ha[i - 1] * seed + s[i];
}
for(int i = 1; i <= len; i++) ans[i] = 0;
for(int i = 1; i <= len; i++){
pa.add(s[i]);
}
pa.count();
for(int i = 1; i <= len; i++){
if(i != 1) printf(" ");
printf("%d", ans[i]);
}
puts("");
}
return 0;
}