PAM(回文自动机)学习笔记
PAM学习笔记
一、概述
感觉\(PAM\)比\(SAM\)什么的好理解多了。。。
顾名思义,回文自动机\(PAM\)就是一个能够高效处理一个字符串所有的回文子串的自动机,也就是一个字符串所有回文子串的信息的高度压缩得到的结果。因此,\(PAM\)能够方便地解决一系列关于回文串的问题。
二、结构
回文自动机维护了原串上的所有本质不同的回文串。
回文自动机的结构可以看成是两棵树,一棵的根是奇根\(odd\),代表着一个长度为\(-1\)的实际上不存在的回文串,存储所有长度为奇数的回文串,一棵的根是偶根\(even\),代表着长度为\(0\)的回文串,存储所有长度为偶数的回文串。
自动机的每一个转移对应树上的一条边,同时带有一个字符\(c\),如果这条边连接\(u\rightarrow v\),那么\(s_v\)就是\(s_u\)在前后分别加上字符\(c\)组成的,\(s_u\)是\(u\)节点维护的回文串。显然,这样做可以表示出所有的回文子串。
与\(AC\)自动机类似,\(PAM\)也带有\(fail\)边,在这里每个点的\(fail\)连向它的最长回文后缀对应的点,同时我们特判\(fail[even]=fail[odd]=odd\),同样的\(fail\)指针也能构成一棵树
翁文涛论文《回文树及其应用》中的示例图:
三、对线性状态数与转移数的证明
-
定理:对于一个字符串\(s\),它的本质不同的回文子串至多有\(|s|\)个
-
证明:考虑数学归纳法
-
当\(|s|=1\)时定理显然成立,当\(|s|>1\)时,设\(s=tc\),假设结论对\(t\)成立,那么增加\(c\)后设新增的回文串的左端点从小到大排序后依次为\(l_1,l_2,\dots l_k\),于是有\(s[l_1\dots |s|]\)是回文串
-
那么我们就有\(s[l_i\dots |s|]=s[l_1,|s|-l_1+l_i]\),当\(i\not= 1\)时\(|s|-l_1+l_i\le|t|\),也就是说\(s[l_i\dots|s|]\)这个回文串已经在\(t\)中出现过了。因此,每增加一个字符,本质不同的回文子串数量最多增加\(1\)。
-
由此可知该定理成立
-
-
因为\(PAM\)中的节点个数\(=\)本质不同的回文子串的个数\(+2\)(去掉\(odd\)与\(even\)),于是\(PAM\)的状态数是线性的,每个状态由唯一的转移边转移而来,因此转移数也是线性的。
四、构造
对于每一个节点,我们需要维护它对应回文串的长度\(len\)以及\(fail\),并用类似\(Trie\)树的方法维护它的儿子,我们用\(0\)表示\(even\),\(1\)表示\(odd\),初始化\(len[0]=0,len[1]=-1\)
采用增量法构造\(PAM\),因为上面的定理,我们每次插入最多只会增加一个节点,即新字符串的最长回文后缀。记\(last\)为上次插入后字符串的最长回文后缀对应的点,初始为\(0\),设我们插入的是第\(i\)个字符\(u\),那么插入后新字符串的最长回文后缀,应当是由原来的一个点\(x\)通过\(u\)转移而来,并且必须满足\(s[i-len[x]-1]=u\),这样就能满足它是回文串了。
找到\(x\)的方法则是从\(last\)开始一路跳\(fail\)直到满足条件为止。设新点为\(tot\),那么\(len[tot]=len[x]+2\)
新点的\(fail\)怎么求呢?如果\(x\)是奇根,那么我们直接将\(fail\)指向\(0\),否则我们同样从\(fail[x]\)开始一路跳\(fail\)直到找到一个满足\(s[i-len[x]-1]=u\)的\(x\),那么\(x\)的\(u\)儿子就是新点的\(fail\)了
对于复杂度证明,\(PAM\)主要看起来不显然是线性的操作就是跳\(fail\),注意到新找到的点的深度最多\(+1\),于是构造的复杂度也是\(\mathcal O(|s|)\)的。
五、例题
-
例题:洛谷模板题
-
即求本质不同的回文子串个数,就等于回文自动机节点数\(-2\)。(去掉奇根与偶根)
-
代码:
-
-
#include<bits/stdc++.h> using namespace std; const int N=5e5+10; char s[N]; int n; namespace PAM{ int fail[N],len[N],dep[N],last,tot;//dep为fail树的深度 int ch[N][26]; inline void init(){ len[0]=0;len[1]=-1;//1为odd,0为even fail[0]=1;fail[1]=1; last=0;tot=1; } inline void newnode(int x,int p,int v){ ch[p][v]=++tot;len[tot]=len[p]+2; if(p==1) fail[tot]=0; else{ p=fail[p]; while(s[x-len[p]-1]!=s[x]) p=fail[p]; fail[tot]=ch[p][v]; } dep[tot]=dep[fail[tot]]+1; } inline int insert(int x){ int p=last,v=s[x]-'a'; while(s[x-len[p]-1]!=s[x]) p=fail[p];//跳fail直至找到能增加转移v的点 if(!ch[p][v]) newnode(x,p,v); last=ch[p][v]; return dep[last]; } } int main(){ scanf("%s",s+1); n=strlen(s+1); PAM::init(); int lastans=0; for(int i=1;i<=n;++i){ s[i]=(s[i]-97+lastans)%26+97; printf("%d ",lastans=PAM::insert(i)); } return 0; }
-
即求一个回文子串在原串中的出现次数
-
考虑每新增一个字符时它的所有回文后缀出现次数都会\(+1\),于是给它的最大回文后缀打一个\(tag\),然后按拓扑序从大到小枚举,将当前点的\(tag\)添加到\(fail\)上即可。在\(PAM\)中,节点本就是按拓扑序加入的,因此直接倒着扫一遍即可。
-
代码:
#include<bits/stdc++.h> using namespace std; const int N=3e5+10; char s[N]; int n; namespace PAM{ int fail[N],len[N],ch[N][26],siz[N]; int last,tot; inline void init(){ len[0]=0;len[1]=-1; fail[0]=fail[1]=1; last=0;tot=1; } inline void newnode(int x,int p,int v){ ch[p][v]=++tot;len[tot]=len[p]+2; if(p==1) fail[tot]=0; else{ p=fail[p]; while(s[x-len[p]-1]!=s[x]) p=fail[p]; fail[tot]=ch[p][v]; } } inline void insert(int x){ int p=last,v=s[x]-'a'; while(s[x-len[p]-1]!=s[x]) p=fail[p]; if(!ch[p][v]) newnode(x,p,v); last=ch[p][v]; siz[last]++; } inline void getsiz(){ long long ans=0; for(int i=tot;i>=2;--i){ siz[fail[i]]+=siz[i]; ans=max(ans,1ll*siz[i]*len[i]); } printf("%lld\n",ans); } } int main(){ scanf("%s",s+1); n=strlen(s+1); PAM::init(); for(int i=1;i<=n;++i) PAM::insert(i); PAM::getsiz(); return 0; }
-
-
-
对于这道题,我们引入一个\(PAM\)非常常用的一个东西:\(trans\)指针,该指针指向不超过回文串长度一半的最长回文后缀,有了它,那么这道题中我们只需要看每一个回文串的长度是否是\(4\)的倍数,且它的\(trans\)的长度恰好是它的一半即可
-
根据定义不难确定\(trans\)的维护方式:新增一个节点时,若节点长度\(\le 2\)则让\(trans=fail\).,否则,设它的父亲为\(u\),先跳到\(trans[u]\),再从\(trans[u]\)一路跳\(fail\)直到跳到一个长度不超过当前回文串一半的回文后缀即可
-
代码:
#include<bits/stdc++.h> using namespace std; const int N=5e5+10; int n,ans; char s[N]; namespace PAM{ int last,tot; int ch[N][26],fail[N],len[N]; int trans[N];//trans指向不超过串长一半的最长回文后缀 inline void init(){ len[0]=0;len[1]=-1; fail[0]=fail[1]=1; last=0;tot=1; } inline void newnode(int x,int p,int v){ ch[p][v]=++tot;len[tot]=len[p]+2; if(p==1) fail[tot]=trans[tot]=0; else{ int rec=p; p=fail[p]; while(s[x-len[p]-1]!=s[x]) p=fail[p]; fail[tot]=ch[p][v]; if(len[tot]==2) trans[tot]=fail[tot]; else{ p=trans[rec]; while(((len[p]+2)<<1)>len[tot]||s[x-len[p]-1]!=s[x]) p=fail[p]; trans[tot]=ch[p][v]; } } } inline int insert(int x){ int p=last,v=s[x]-'a'; while(s[x-len[p]-1]!=s[x]) p=fail[p]; if(!ch[p][v]) newnode(x,p,v); last=ch[p][v]; if(len[trans[last]]==len[last]/2&&len[last]%4==0) return len[last]; else return 0; } } int main(){ scanf("%d",&n); scanf("%s",s+1); PAM::init(); for(int i=1;i<=n;++i){ int x=PAM::insert(i); ans=max(ans,x); } printf("%d\n",ans); return 0; }
-
-
例题:loj#141 回文子串
这道题要求我们同时支持在字符串前后添加,于是我们初始将字符串的左右端点设为\(4e5+1\)和\(4e5\),然后对前后插入时分别维护一个\(llast\)与\(rlast\),对每个节点维护最长回文前缀与最长回文后缀,后者就是\(fail\),而因为是回文串,所有最长回文前缀就等于最长回文后缀所以直接使用\(fail\)即可,注意当插入字符后整个字符串成为回文串时,要将\(llast\)与\(rlast\)合并
#include<bits/stdc++.h> using namespace std; typedef long long ll; const int N=4e5+10; namespace PAM{ char s[N<<1]; ll ans; int L,R,llast,rlast,tot,ch[N][26],len[N],fail[N],dep[N]; inline void init(){ len[0]=0;len[1]=-1; fail[0]=1;fail[1]=1; llast=rlast=0;tot=1; L=N,R=N-1; } inline void push_back(char c){ int v=c-'a'; s[++R]=c; int p=rlast; while(s[R-len[p]-1]!=c) p=fail[p]; if(!ch[p][v]){ ch[p][v]=++tot;len[tot]=len[p]+2; if(p==1) fail[tot]=0; else{ int now=fail[p]; while(s[R-len[now]-1]!=c) now=fail[now]; fail[tot]=ch[now][v]; } dep[tot]=dep[fail[tot]]+1; } rlast=ch[p][v]; ans+=dep[rlast]; if(R-L+1==len[rlast]) llast=rlast; } inline void push_front(char c){ int v=c-'a'; s[--L]=c; int p=llast; while(s[L+len[p]+1]!=c) p=fail[p]; if(!ch[p][v]){ ch[p][v]=++tot;len[tot]=len[p]+2; if(p==1) fail[tot]=0; else{ int now=fail[p]; while(s[L+len[now]+1]!=c) now=fail[now]; fail[tot]=ch[now][v]; } dep[tot]=dep[fail[tot]]+1; } llast=ch[p][v]; ans+=dep[llast]; if(len[llast]==R-L+1) rlast=llast; } } using PAM::ans; int q; char ch[N]; int main(){ // freopen("in.in","r",stdin); // freopen("ans.ans","w",stdout); scanf("%s",ch+1); int len=strlen(ch+1); PAM::init(); for(int i=1;i<=len;++i) PAM::push_back(ch[i]); scanf("%d",&q); while(q--){ int op; scanf("%d",&op); if(op<=2){ scanf("%s",ch+1); len=strlen(ch+1); for(int i=1;i<=len;++i){ if(op==1) PAM::push_back(ch[i]); else PAM::push_front(ch[i]); } } else printf("%lld\n",ans); } return 0; }
-
例题:洛谷P5555 秩序魔咒