manacher算法模板
manacher模板
今天考了个回文的题,于是在520巨佬的指导下学习了一波manacher.先推荐一波520大佬的博客
题目描述
给出一个只由小写英文字符a,b,c...y,z组成的字符串S,求S中最长回文串的长度.
字符串长度为n
输入输出格式
输入格式:
一行小写英文字符a,b,c...y,z组成的字符串S
输出格式:
一个整数表示答案
输入输出样例
输入样例#1:
aaa
输出样例#1:
3
说明
字符串长度len <= 11000000
题意就是要求一段字符串中的最长回文串.
首先想一下朴素算法是如何实现的:我们可以枚举一个字串的起点终点,然后用\(O(len)\)的时间验证,这样的总时间复杂度是\(O(n^3)\)的.
根据回文的性质,我们可以不用又枚举起点又枚举终点,因为回文是关于它的对称轴对称的,所以我们可以考虑直接枚举它的对称轴,然后向两边扩展.这样的总时间复杂度是\(O(n^2)\)的.
因为一个字符串最多只有\(len\)个对称轴,所以对称轴相同的那些回文串都可以由对称轴这一点拓展而来. 但是因为对称轴有可能是在两个字符的中间(也就是这个回文串的长度是偶数),这样会使得接下来的操作很不方便,所以我们将原字符串中每两个字符的中间都插入一个同一个特殊符号方便判断,比如'#'什么的.打个栗子:
Brave_Cattle -> #B#r#a#v#e#_#C#a#t#t#l#e#
比如这个两个t,显然它们的对称轴是在这两个字符中间的,那么这个插入的特殊符号就给我们判断省了很多事.
为了求出长度最长的回文串,显然我们要求出回文串半径最大的那一个.
这时候我们需要一个数组\(p[i]\)来记录下从下标\(i\)开始最多能拓展的回文半径. 因为在每两个字符中间都插入了一个特殊字符所以记录下所有\(p[i]\)之后最长的字串的长度就是最大的\(p[i]-1\).
那么问题的重点就来了:如何快速求出\(p[i]\)数组. 我们可以发现,如果有一个回文串是一个长回文串的子串的话,那么这个回文串的长度可以直接由之前记录的对称轴另一边的那个对称的字串推出(这个一定要看图理解).
我们用\(id\)表示目前作为中间回文来推其他回文串的半径的那个回文的对称轴的下标(也就是表示的从下标 \(mx\)关于\(id\)的对称点 到下标\(mx\) 的那个回文串,\(mx\)表示\(p[id]\),也就是该回文串的最右边的下标).在图中就是最下面这根直线所代表的区间,我们叫它\(A\)区间.
此时有一个以\(i\)为对称轴的回文字串,也就是图中\(i\)下面那条线所代表的区间范围,我们叫它\(I\)区间.根据数学知识,我们可以得到\(i\)关于\(id\)的对称点\(j=id*2-i\)(如果不知道就自己画根数轴模拟一下吧),以这个对称轴得到的回文串我们叫它\(J\)区间.因为\(I,J\)区间都是属于\(A\)区间的回文串,而且他们关于\(id\)对称,所以这两个区间的半径的长度是一样的.
if(i < mx) p[i] = p[id*2-i];
但是我们还需要考虑一种情况: 如果\(I\)区间的右端点超过了\(A\)区间,那么此时\(I\)区间的半径是取不到和\(J\)区间一样大的半径的,所以我们需要判断是否会发生这种情况.如下图:
if(i < mx) p[i] = min(p[pos*2-i],mx-i);
当然这样得到了\(p[i]\)还需要再往后扩展一下,因为有可能后面还存在可以扩展的情况.就判断一下是否处理后的字符串的第\(i-p[i]\)位和第\(i+p[i]\)位是否相同.
最后我们还需要在循环的时候更新一下作为长串来处理的那个回文串.我们要选最远的那个,因为这样就可以让循环的次数减少.
下面看一下终点的代码注释吧:
void manacher(){
int pos = 0, mx = 0, ans = 0;
for(int i=1;i<=cnt;i++){
if(i < mx) p[i] = min(p[pos*2-i],mx-i);//处理
else p[i] = 1;//否则以i为对称轴的这个串就属于长串的范围外,无法直接得到p[i]值.
while(ss[i+p[i]] == ss[i-p[i]]) p[i]++;//还要再扩展一次
if(mx < i+p[i]) mx = i+p[i], pos = i;//更新作为长串的回文串
ans = max(ans,p[i]-1);
}
}
大概内容就讲的差不多了...如果不懂的话可能只能再多自己出一点数据模拟一下这个过程了.
下面贴一下完整代码吧
#include<bits/stdc++.h>
using namespace std;
const int N=11000000+5;
const int inf=2147483647;
int cnt, len, ans = 0;
char s[N], ss[N*2];
int p[N*2];
void init(){//将每两个字符中插入一个字符
len = strlen(s), cnt = 1;
ss[0] = '!'; ss[cnt] = '#';
for(int i=0;i<len;i++)
ss[++cnt] = s[i], ss[++cnt] = '#';
}
void manacher(){
int pos = 0, mx = 0;
for(int i=1;i<=cnt;i++){
if(i < mx) p[i] = min(p[pos*2-i],mx-i);
else p[i] = 1;
while(ss[i+p[i]] == ss[i-p[i]]) p[i]++;
if(mx < i+p[i]) mx = i+p[i], pos = i;
ans = max(ans,p[i]-1);
}
}
int main(){
scanf("%s",s);
init(); manacher();
printf("%d\n",ans);
return 0;
}