bzoj 2342: 双倍回文 回文自动机

题目大意:

定义双倍回文串的左一半和右一半均是回文串的长度为4的倍数的回文串
求一个给定字符串中最长的双倍回文串的长度

题解:

  1. 我们知道可以简单地判定以某一点结尾的最长回文串
  2. 我们知道可以简单地判定以某一点开头的最长回文串

啥?第二个?你把串倒过来不就行了?

所以我们枚举双倍回文串的断点再判定即可.
我们发现我们每次都要取枚举到的两个端点的最长的相同偶数长度的回文串
并且这两个回文串还要相同。。也就是说在回文自动机上这是同一个点
所以我们在fail树上求lca即可

#include <cstdio>
#include <cstring>
#include <algorithm>
#include <map>
using namespace std;
typedef long long ll;
inline void read(int &x){
	x=0;char ch;bool flag = false;
	while(ch=getchar(),ch<'!');if(ch == '-') ch=getchar(),flag = true;
	while(x=10*x+ch-'0',ch=getchar(),ch>'!');if(flag) x=-x;
}
inline int cat_max(const int &a,const int &b){return a>b ? a:b;}
inline int cat_min(const int &a,const int &b){return a<b ? a:b;}
const int maxn = 1000010;
struct Edge{
	int to,next;
}G[maxn];
int head[maxn],cnt;
void add(int u,int v){
	G[++cnt].to = v;
	G[cnt].next = head[u];
	head[u] = cnt;
}
#define v G[i].to
int son[maxn],top[maxn],siz[maxn],dep[maxn],fa[maxn];
void dfs(int u){
	siz[u] = 1;
	for(int i = head[u];i;i=G[i].next){
		if(v == fa[u]) continue;
		fa[v] = u;
		dep[v] = dep[u] + 1;
		dfs(v);
		siz[u] += siz[v];
		if(son[u] == -1 || siz[son[u]] < siz[v]) son[u] = v;
	}
}
void dfs(int u,int tp){
	top[u] = tp;
	if(son[u] != -1) dfs(son[u],tp);
	for(int i = head[u];i;i=G[i].next){
		if(v == son[u] || v == fa[u]) continue;
		dfs(v,v);
	}
}
#undef v
inline int lca(int u,int v){
	while(top[u] != top[v]){
		if(dep[top[u]] < dep[top[v]]) swap(u,v);
		u = fa[top[u]];
	}return dep[u] < dep[v] ? u : v;
}
struct Node{
	map<int,int>nx;
	int fail,len;
}T[maxn];
int mp[maxn],last,nodecnt,len,str[maxn];
int f[maxn];
inline void init(){
	T[last = nodecnt = 0].fail = 1;
	T[++nodecnt].len = -1;
	str[len=0] = -1;add(1,0);
	f[0] = f[1] = -1;
}
char s[maxn];
inline void insert(int i){
	int c = s[i] - 'a',cur,p,x;str[++len] = c;
	for(p = last;str[len-T[p].len-1] != str[len];p = T[p].fail);
	if(T[p].nx[c] == 0){
		T[cur = ++ nodecnt].len = T[p].len + 2;
		for(x = T[p].fail;str[len-T[x].len-1] != str[len];x = T[x].fail);
		T[cur].fail = T[x].nx[c];T[p].nx[c] = cur;
		f[cur] = f[T[cur].fail];
		if(T[cur].len % 2 == 0) f[cur] = cur;
		add(T[cur].fail,cur);
	}mp[i] = last = T[p].nx[c];
}
int main(){init();
	memset(fa,-1,sizeof fa);
	memset(son,-1,sizeof son);
	int n;read(n);scanf("%s",s);
	s[n] = 'z' + 1;s[n+1] = 'z' + 2;
	for(int i=0;i<n;++i){
		s[(n+1)+(n-i)] = s[i];
	}
	int len = (n << 1) + 2;
	for(int i=0;i<len;++i) insert(i);
	dfs(1);dfs(1,1);int ans = 0;
	for(int i=0;i<n-1;++i){
		int x = lca(mp[i],mp[(n+1)+(n-i-1)]);
		if(f[x] != -1) ans = max(ans,T[f[x]].len<<1);
	}printf("%d\n",ans);
	getchar();getchar();
	return 0;
}
posted @ 2017-03-13 07:10  Sky_miner  阅读(243)  评论(0编辑  收藏  举报