BZOJ 3676 [APIO2014] 回文串:后缀自动机 + Manacher(马拉车)+ 树上倍增
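思路:一个串本质不同的回文子串至多 n 个,且只有 Manacher(马拉车)向右扩展回文半径时才可能出现新的回文串,因此只需检查 O(n) 个候选回文串。对每个候选串,从其结尾位置对应的前缀状态出发,在后缀自动机的 parent 树(后缀链接树)上倍增,跳到包含该长度的状态,该状态的 endpos 大小即为出现次数;答案为所有候选的 长度 × 出现次数 的最大值。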
// Reads one lowercase string and prints max(length * occurrence count) over
// all of its palindromic substrings.
//
// Approach: build a suffix automaton (SAM) of the string and compute |endpos|
// (occurrence count) for every state. Enumerate candidate palindromes with
// Manacher; only palindromes discovered while a radius is being extended can
// be new, so there are O(n) candidates. For each candidate, binary-lift from
// the state of the prefix ending at the palindrome's last character up the
// suffix-link tree to the highest state whose longest length is still >= the
// palindrome length; that state's |endpos| is the palindrome's count.
#include <bits/stdc++.h>

using LL = long long;

const int N = 300000 + 7;  // capacity: input strings of up to ~3e5 characters
const int LOG = 20;        // lifting levels; 2^LOG exceeds the suffix-link tree depth (<= n + 1)

int n;           // length of the input string
int p[N << 1];   // Manacher radii: written by the odd pass, then overwritten by the even pass
char s[N << 1];  // input, 1-indexed; solve() turns s[0] and s[n + 1] into sentinels

struct SuffixAutomaton {
    int last, cur, cnt;   // previous last state, current last state, number of states (root = 1)
    int ch[N << 1][26];   // transitions on 'a'..'z'; 0 means "no transition"
    int id[N << 1];       // states in increasing order of dis (filled by getSize)
    int fa[N << 1];       // suffix link; the root's link is 0
    int dis[N << 1];      // length of the longest string in the state
    int sz[N << 1];       // |endpos| of the state once getSize() has run
    // Counting-sort buckets indexed by string length. Sized like the state
    // arrays (not N) so init() can clear c[0..cnt] without running past it.
    int c[N << 1];
    int f[N << 1][LOG];   // f[v][j]: 2^j-th suffix-link ancestor of v (0 once past the root)
    int pos[N << 1];      // pos[i]: state reached by reading the prefix s[1..i] from the root

    SuffixAutomaton() { cur = cnt = 1; }

    // Reset to an empty automaton so it can be rebuilt for another string.
    // Index 0 is included: getSize() dirties c[0] (the bucket of the root)
    // and sz[0] (the root's parent slot).
    void init() {
        for (int i = 0; i <= cnt; i++) {
            memset(ch[i], 0, sizeof(ch[i]));
            sz[i] = c[i] = dis[i] = fa[i] = 0;
        }
        cur = cnt = 1;
    }

    // Append character `x` (0..25); `len` is the length of the prefix that
    // ends with this character (the i-th call passes i).
    void extend(int x, int len) {
        last = cur;
        cur = ++cnt;
        dis[cur] = len;
        int v = last;
        for (; v && !ch[v][x]; v = fa[v]) ch[v][x] = cur;
        if (!v) {
            fa[cur] = 1;
        } else {
            int q = ch[v][x];
            if (dis[q] == dis[v] + 1) {
                fa[cur] = q;
            } else {
                // Split q: the clone takes the shorter strings of q.
                int clone = ++cnt;
                dis[clone] = dis[v] + 1;
                memcpy(ch[clone], ch[q], sizeof(ch[q]));
                fa[clone] = fa[q];
                sz[clone] = 0;  // a clone owns no end position of its own
                fa[q] = fa[cur] = clone;
                for (; ch[v][x] == q; v = fa[v]) ch[v][x] = clone;
            }
        }
        sz[cur] = 1;  // the new prefix state owns exactly one end position
    }

    // Counting-sort states by dis, then push |endpos| up the suffix links
    // from longer states to shorter ones. `len` is the string length.
    void getSize(int len) {
        for (int i = 1; i <= cnt; i++) c[dis[i]]++;
        for (int i = 1; i <= len; i++) c[i] += c[i - 1];
        for (int i = cnt; i >= 1; i--) id[c[dis[i]]--] = i;
        for (int i = cnt; i >= 1; i--) sz[fa[id[i]]] += sz[id[i]];
    }

    // Value (len * occurrences) of the length-`len` substring that ends at
    // the same positions as state `v`. Requires dis[v] >= len and a built `f`.
    LL query(int v, int len) const {
        for (int j = LOG - 1; j >= 0; j--) {
            if (f[v][j] && dis[f[v][j]] >= len) v = f[v][j];
        }
        return 1LL * len * sz[v];
    }

    // Fill pos[] for every prefix and the binary-lifting table over suffix links.
    void buildLifting() {
        for (int i = 1, v = 1; i <= n; i++) {
            v = ch[v][s[i] - 'a'];
            pos[i] = v;
        }
        for (int i = 1; i <= cnt; i++) f[i][0] = fa[i];
        for (int j = 1; j < LOG; j++)
            for (int i = 1; i <= cnt; i++) f[i][j] = f[f[i][j - 1]][j - 1];
    }

    // Manacher over odd-length palindromes: p[i] is the radius at centre i,
    // covering s[i-p[i]+1 .. i+p[i]-1]. Returns the best value found.
    LL oddPass() {
        LL best = 0;
        int mx = 0, center = 0;  // exclusive right end of the furthest-reaching palindrome, and its centre
        for (int i = 1; i <= n; i++) {
            if (mx > i) {
                // Radii inherited from the mirror are reflections of palindromes
                // seen earlier, so they need no query.
                p[i] = std::min(mx - i, p[2 * center - i]);
            } else {
                p[i] = 1;
                best = std::max(best, query(pos[i], 1));
            }
            while (s[i + p[i]] == s[i - p[i]]) {
                p[i]++;
                best = std::max(best, query(pos[i + p[i] - 1], 2 * p[i] - 1));
            }
            if (i + p[i] > mx) {
                mx = i + p[i];
                center = i;
            }
        }
        return best;
    }

    // Manacher over even-length palindromes: p[i] is the half-length of the
    // palindrome centred between i and i+1, covering s[i-p[i]+1 .. i+p[i]].
    LL evenPass() {
        LL best = 0;
        int mx = 0, center = 0;  // inclusive right end of the furthest-reaching palindrome, and its centre
        for (int i = 1; i <= n; i++) {
            p[i] = mx > i ? std::min(mx - i, p[2 * center - i]) : 0;
            while (s[i + p[i] + 1] == s[i - p[i]]) {
                p[i]++;
                best = std::max(best, query(pos[i + p[i]], 2 * p[i]));
            }
            if (i + p[i] > mx) {
                mx = i + p[i];
                center = i;
            }
        }
        return best;
    }

    // Print the maximum of length * occurrences over all palindromic substrings.
    void solve() {
        buildLifting();
        // Distinct sentinels stop Manacher's expansion at both ends of the string.
        s[0] = '-';
        s[n + 1] = '+';
        LL ans = oddPass();
        ans = std::max(ans, evenPass());
        printf("%lld\n", ans);
    }
} sam;  // global: the tables are far too large for the stack

int main() {
    // On a failed read the string is treated as empty, so "0" is printed.
    if (scanf("%s", s + 1) != 1) s[1] = '\0';
    n = static_cast<int>(strlen(s + 1));
    for (int i = 1; i <= n; i++) sam.extend(s[i] - 'a', i);
    sam.getSize(n);
    sam.solve();
    return 0;
}