回文树/回文自动机(PAM)学习笔记
回文树(也就是回文自动机)实际上是奇偶两棵树,每一个节点代表一个本质不同的回文子串(一棵树上的串长度全部是奇数,另一棵全部是偶数),原串中每一个本质不同的回文子串都在树上出现一次且仅一次。
一个节点的fail指针指向它的最长回文后缀(不包括自身,所有空fail均连向1)。归纳容易证明,当在原串末尾新增一个字符时,回文树上至多会新增一个节点,这也证明了一个串本质不同的回文子串个数不会超过n。
建树时采用增量构造法,当考虑新字符s[i]时,先找到以s[i-1]为结尾的节点p,并不断跳fail。若代表新增回文子串的节点已存在则直接结束,否则通过fail[p]不断跳fail找到新节点的fail。
0,1号节点均不代表串,常数大于manacher。初始化fail[0]=fail[1]=1,len[1]=-1,tot=1,last=0。
[BZOJ2160]拉拉队排练
建立后缀树后树上DP求出每种回文子串的出现次数即可。
1 #include<cstdio> 2 #include<algorithm> 3 #define rep(i,l,r) for (int i=(l); i<=(r); i++) 4 typedef long long ll; 5 using namespace std; 6 7 const int N=1200010,mod=19930726; 8 char s[N]; 9 ll K,sm; 10 int n,ans=1,lst,nd=1,len[N],fail[N],son[N][27],sz[N]; 11 struct P{ int l,c; }c[N]; 12 bool operator <(const P &a,const P &b){ return a.l>b.l; } 13 14 int ksm(int a,int b){ 15 int s=1; 16 for (; b; a=1ll*a*a%mod,b>>=1) 17 if (b & 1) s=1ll*s*a%mod; 18 return s; 19 } 20 21 void ext(int c,int n,char s[]){ 22 int p=lst; 23 while (s[n-len[p]-1]!=s[n]) p=fail[p]; 24 if (!son[p][c]){ 25 int np=++nd,q=fail[p]; 26 while (s[n-len[q]-1]!=s[n]) q=fail[q]; 27 len[np]=len[p]+2; fail[np]=son[q][c]; son[p][c]=np; 28 } 29 lst=son[p][c]; sz[lst]++; 30 } 31 32 int main(){ 33 freopen("bzoj2160.in","r",stdin); 34 freopen("bzoj2160.out","w",stdout); 35 scanf("%d%lld%s",&n,&K,s+1); 36 len[1]=-1; fail[1]=fail[0]=1; 37 rep(i,1,n) ext(s[i]-'a',i,s); 38 for (int i=nd; i; i--) sz[fail[i]]+=sz[i]; 39 rep(i,2,nd) c[i-1]=(P){len[i],sz[i]}; 40 sort(c+1,c+nd); 41 rep(i,1,nd-1){ 42 if (!(c[i].l&1)) continue; 43 ll t=min(K,(ll)c[i].c); ans=1ll*ans*ksm(c[i].l,t)%mod; K-=t; 44 if (!K) break; 45 } 46 printf("%d\n",K?-1:ans); 47 return 0; 48 }
[BZOJ3676][APIO2014]回文串
显然建出回文树后求出每个点的出现次数与长度即可。
1 #include<cstdio> 2 #include<cstring> 3 #include<algorithm> 4 #define rep(i,l,r) for (int i=(l); i<=(r); i++) 5 typedef long long ll; 6 using namespace std; 7 8 const int N=300010; 9 char s[N]; 10 ll ans; 11 int n,lst,nd=1,len[N],fail[N],sz[N],son[N][27]; 12 13 void ext(int c,int n,char s[]){ 14 int p=lst; 15 while (s[n-len[p]-1]!=s[n]) p=fail[p]; 16 if (!son[p][c]){ 17 int np=++nd,q=fail[p]; 18 while (s[n-len[q]-1]!=s[n]) q=fail[q]; 19 len[np]=len[p]+2; fail[np]=son[q][c]; son[p][c]=np; 20 } 21 sz[lst=son[p][c]]++; 22 } 23 24 int main(){ 25 freopen("bzoj3676.in","r",stdin); 26 freopen("bzoj3676.out","w",stdout); 27 scanf("%s",s+1); n=strlen(s+1); 28 fail[0]=fail[1]=1; len[1]=-1; s[0]=-1; 29 rep(i,1,n) ext(s[i]-'a',i,s); 30 for (int i=nd; i; i--) sz[fail[i]]+=sz[i]; 31 rep(i,2,nd) ans=max(ans,1ll*sz[i]*len[i]); 32 printf("%lld\n",ans); 33 return 0; 34 }
[CF17E]Palisection
正难则反,所有回文串对数减去不相交对数。以某个位置结尾的回文子串个数等于它在回文树上代表的节点的深度,后缀和优化一下即可。
1 #include<cstdio> 2 #include<algorithm> 3 #define rep(i,l,r) for (int i=(l); i<=(r); i++) 4 #define For(i,x) for (int i=h[x],k; i; i=nxt[i]) 5 typedef long long ll; 6 using namespace std; 7 8 const int N=2000010,mod=51123987; 9 char s[N]; 10 int n,cnt,ans,nd,lst,p1[N],p2[N],dep[N],to[N],nxt[N],val[N],h[N],fail[N],len[N]; 11 12 void add(int u,int v,int w){ to[++cnt]=v; nxt[cnt]=h[u]; val[cnt]=w; h[u]=cnt; } 13 14 void init(){ 15 rep(i,0,nd) h[i]=fail[i]=len[i]=dep[i]=0; 16 cnt=lst=0; nd=1; len[1]=-1; fail[0]=fail[1]=1; 17 } 18 19 int son(int x,int c){ For(i,x) if (val[i]==c) return k=to[i]; return 0; } 20 21 void ext(int c,int n,char s[]){ 22 int p=lst; 23 while (s[n-len[p]-1]!=s[n]) p=fail[p]; 24 if (!son(p,c)){ 25 int np=++nd,q=fail[p]; 26 while (s[n-len[q]-1]!=s[n]) q=fail[q]; 27 len[np]=len[p]+2; fail[np]=son(q,c); dep[np]=dep[fail[np]]+1; add(p,np,c); 28 } 29 lst=son(p,c); 30 } 31 32 int main(){ 33 freopen("cf17e.in","r",stdin); 34 freopen("cf17e.out","w",stdout); 35 scanf("%d%s",&n,s+1); init(); 36 rep(i,1,n) ext(s[i]-'a',i,s),p1[i]=dep[lst],ans=(ans+p1[i])%mod; 37 ans=1ll*ans*(ans-1)/2%mod; reverse(s+1,s+n+1); init(); 38 rep(i,1,n) ext(s[i]-'a',i,s),p2[n-i+1]=dep[lst]; 39 for (int i=n; i; i--) p2[i]=(p2[i]+p2[i+1])%mod; 40 rep(i,1,n) ans=(ans-1ll*p1[i]*p2[i+1]%mod+mod)%mod; 41 printf("%d\n",ans); 42 return 0; 43 }
[Aizu2292]Common Palindromes
给定S,T,询问有多少(l1,r1,l2,r2)使得S[l1,r1]回文且S[l1,r1]=T[l2,r2]。
显然对S建出回文自动机然后T在上面跑,记录每个S中回文串的出现次数以及T中有多少个子串与此串匹配。注意初始x=1(可以认为回文树的根是1),且匹配是不仅要看此节点是否有对应子节点,也要看s[i-len[x]-1]是否等于s[i]。
1 #include<cstdio> 2 #include<cstring> 3 #include<algorithm> 4 #define rep(i,l,r) for (int i=(l); i<=(r); i++) 5 typedef long long ll; 6 using namespace std; 7 8 const int N=60010; 9 char s[N]; 10 ll ans; 11 int n,lst,nd=1,f[N],son[N][27],fail[N],len[N],sz[N]; 12 13 void ext(int c,int n,char s[]){ 14 int p=lst; 15 while (s[n-len[p]-1]!=s[n]) p=fail[p]; 16 if (!son[p][c]){ 17 int np=++nd,q=fail[p]; 18 while (s[n-len[q]-1]!=s[n]) q=fail[q]; 19 len[np]=len[p]+2; fail[np]=son[q][c]; son[p][c]=np; 20 } 21 sz[lst=son[p][c]]++; 22 } 23 24 int main(){ 25 freopen("Aizu2292.in","r",stdin); 26 freopen("Aizu2292.out","w",stdout); 27 len[1]=-1; fail[0]=fail[1]=1; 28 scanf("%s",s+1); n=strlen(s+1); 29 rep(i,1,n) ext(s[i]-'A',i,s); 30 scanf("%s",s+1); n=strlen(s+1); int x=1; 31 rep(i,1,n){ 32 int c=s[i]-'A'; 33 while (x!=1 && (!son[x][c] || s[i]!=s[i-len[x]-1])) x=fail[x]; 34 if (son[x][c] && s[i]==s[i-len[x]-1]) x=son[x][c],f[x]++; 35 } 36 for (int i=nd; i; i--) f[fail[i]]+=f[i],sz[fail[i]]+=sz[i]; 37 rep(i,2,nd) ans+=1ll*f[i]*sz[i]; 38 printf("%lld\n",ans); 39 return 0; 40 }
[BZOJ2342][SHOI2011]双倍回文
就是求后半段也为回文串的回文串个数,在fail树上DFS并维护每个长度的回文串个数即可。
或者考虑求half[i]表示节点i的最深祖先满足len[half[i]]<=len[i]/2,这个同样可以在建树的时候求得。
1 #include<cstdio> 2 #include<algorithm> 3 #define rep(i,l,r) for (int i=(l); i<=(r); i++) 4 typedef long long ll; 5 using namespace std; 6 7 const int N=600010; 8 char s[N]; 9 int n,ans,lst,nd=1,fail[N],son[N][27],len[N],half[N]; 10 11 void ext(int c,int n,char s[]){ 12 int p=lst; 13 while (s[n-len[p]-1]!=s[n]) p=fail[p]; 14 if (!son[p][c]){ 15 int np=++nd,q=fail[p]; 16 while (s[n-len[q]-1]!=s[n]) q=fail[q]; 17 len[np]=len[p]+2; fail[np]=son[q][c]; son[p][c]=np; 18 if (len[np]==1) half[np]=0; 19 else{ 20 int pos=half[p]; 21 while (s[n-len[pos]-1]!=s[n] || (len[pos]+2)*2>len[np]) pos=fail[pos]; 22 half[np]=son[pos][c]; 23 } 24 } 25 lst=son[p][c]; 26 } 27 28 int main(){ 29 freopen("bzoj2342.in","r",stdin); 30 freopen("bzoj2342.out","w",stdout); 31 scanf("%d%s",&n,s+1); len[1]=-1; fail[0]=fail[1]=1; 32 rep(i,1,n) ext(s[i]-'a',i,s); 33 rep(i,2,nd) if (len[half[i]]*2==len[i] && len[i]%4==0) ans=max(ans,len[i]); 34 printf("%d\n",ans); 35 return 0; 36 }