回文自动机pam

目的:结构上类似于"回文串的 Trie 树 + AC 自动机的 fail 指针",每个节点对应一个本质不同的回文子串,可以用来统计各种与回文子串相关的量(个数、出现次数等)

复杂度:构建为均摊 O(n)(子节点用大小为 26 的数组存储;若改用 map 存储子节点则为 O(n log |Σ|))

https://blog.csdn.net/Lolierl/article/details/99971257

 

 

https://www.luogu.org/problem/P5496

求出以每个位置结尾的回文子串个数,强制在线

#include<iostream>
#include<cstdio>
#include<cstring>
using namespace std;
// Capacity of the node pool and of the input buffers (2,000,010 entries).
const int maxn=2000000+10;
// One node of the palindromic automaton; each node stands for one distinct palindrome.
struct pam_trie
{
    int ch[26]; // ch[x]: node obtained by adding letter x on both sides (0 = not created yet)
    int fail;   // node of the longest proper palindromic suffix
    int len;    // length of this node's palindrome (-1 for the odd root)
    int num;    // length of the fail chain, i.e. number of palindromic suffixes including itself
};
// Palindromic automaton for Luogu P5496: after each appended character it
// prints how many palindromic substrings end at that position. The input is
// decoded online, each letter being shifted by the previous answer.
struct pam
{
    pam_trie b[maxn];   // node pool; b[0] = even root (len 0), b[1] = odd root (len -1)
    int n;              // 1-based index of the character currently being appended
    int length;         // length of the input string
    int last;           // node of the longest palindromic suffix of the processed prefix
    int cnt;            // index of the most recently allocated node
    int s[maxn];        // letters mapped to 0..25; s[0] holds a sentinel
    char c[maxn];       // input text, stored starting at c[1]

    // Links the two roots together. All other fields are left untouched here;
    // the only instance, P, is a global and therefore starts zero-filled.
    pam():last(0),cnt(1)
    {
        b[0].len=0;
        b[0].fail=1;
        b[1].len=-1;
        b[1].fail=0;
    }

    // Reads one whitespace-delimited word into c[1..] and records its length.
    void read()
    {
        char *text=c+1;
        scanf("%s",text);
        length=strlen(text);
    }

    // Follows fail links from x until the palindrome of the reached node can be
    // extended by s[n] on both sides, and returns that node. The odd root
    // (len -1) compares s[n] with itself, so the walk always stops.
    int get_fail(int x)
    {
        const int letter=s[n];
        for(;;)
        {
            if(s[n-b[x].len-1]==letter)return x;
            x=b[x].fail;
        }
    }

    // Appends s[n]: creates the node of the new longest palindromic suffix if
    // it does not exist yet, then refreshes `last` and its `num`.
    void insert()
    {
        const int letter=s[n];
        const int parent=get_fail(last);
        int &slot=b[parent].ch[letter];
        if(slot==0)
        {
            const int node=++cnt;
            b[node].len=b[parent].len+2;
            // The fail link must be resolved before the new child is attached.
            b[node].fail=b[get_fail(b[parent].fail)].ch[letter];
            slot=node;
        }
        last=slot;
        b[last].num=b[b[last].fail].num+1;
    }

    // Decodes and inserts every character, printing the answer for each
    // position followed by a space.
    void solve()
    {
        s[0]=26;    // sentinel: differs from every real letter (0..25)
        int prev=0; // answer of the previous position, used as the decoding shift
        n=1;
        while(n<=length)
        {
            c[n]=(c[n]-97+prev)%26+97;
            s[n]=c[n]-'a';
            insert();
            prev=b[last].num;
            printf("%d ",prev);
            ++n;
        }
    }
}P;

// Entry point: loads the input string into the global automaton P, then
// processes it character by character, printing one answer per position.
int main()
{
    P.read();   // scanf the string into P.c and record its length
    P.solve();  // decode, insert and print the count for every position
    return 0;
}
View Code

https://www.luogu.org/problem/P3649

求回文子串出现次数*长度的最大值

#include<iostream>
#include<stdio.h>
#include<cstdio>
#include<cstring>
using namespace std;
// Capacity of the node pool and of the input buffers (300,010 entries).
const int maxn=300000+10;
// One node of the palindromic automaton; each node stands for one distinct palindrome.
struct pam_trie
{
    int ch[26]; // ch[x]: node obtained by adding letter x on both sides (0 = not created yet)
    int fail;   // node of the longest proper palindromic suffix
    int len;    // length of this node's palindrome (-1 for the odd root)
    int sum;    // counter bumped on insertion and later accumulated along fail links
};
// Palindromic automaton for Luogu P3649: prints the maximum of
// (number of occurrences) * (length) over all palindromic substrings.
struct pam
{
    pam_trie b[maxn];   // node pool; b[0] = even root (len 0), b[1] = odd root (len -1)
    int n;              // 1-based index of the character currently being appended
    int length;         // length of the input string
    int last;           // node of the longest palindromic suffix of the processed prefix
    int cnt;            // index of the most recently allocated node
    int s[maxn];        // letters mapped to 0..25; s[0] holds a sentinel
    char c[maxn];       // input text, stored starting at c[1]
    long long ans;      // best occurrences * length value found by solve()

    // Links the two roots together and marks the automaton as empty.
    pam()
    {
        b[0].len=0;b[1].len=-1;
        b[0].fail=1;b[1].fail=0;
        last=0;
        cnt=1;
    }

    // Reads one whitespace-delimited word into c[1..] and records its length.
    void read()
    {
        scanf("%s",c+1);
        length=strlen(c+1);
    }

    // Follows fail links from x until the palindrome of the reached node can be
    // extended by s[n] on both sides. The odd root (len -1) always qualifies.
    int get_fail(int x)
    {
        while(s[n-b[x].len-1]!=s[n])x=b[x].fail;
        return x;
    }

    // Appends s[n], creating the node of the new longest palindromic suffix if
    // needed, and counts one occurrence for that node.
    void insert()
    {
        int p=get_fail(last);
        if(!b[p].ch[s[n]])
        {
            b[++cnt].len=b[p].len+2;
            // The fail link must be resolved before the new child is attached.
            b[cnt].fail=b[get_fail(b[p].fail)].ch[s[n]];
            b[p].ch[s[n]]=cnt;
        }
        last=b[p].ch[s[n]];
        b[last].sum++;
    }

    // Builds the automaton, propagates occurrence counts to palindromic
    // suffixes and prints the best occurrences * length product.
    void solve()
    {
        s[0]=26;    // sentinel: differs from every real letter (0..25)
        for(n=1;n<=length;n++)
        {
            s[n]=c[n]-'a';
            insert();
        }
        ans=0;
        // A fail target always has a smaller index than the node pointing to
        // it, so a descending sweep finalises each count before it is pushed on.
        for(int i=cnt;i>0;i--)
        {
            b[b[i].fail].sum+=b[i].sum;
            // Plain comparison instead of max(): this program never includes
            // <algorithm>, so max() was only available through a transitive include.
            const long long score=1ll*b[i].sum*b[i].len;
            if(score>ans)ans=score;
        }
        printf("%lld\n",ans);
    }
}P;
// Entry point: loads the string into the global automaton P, then computes
// and prints the maximum "occurrences * length" value.
int main()
{
    P.read();
    P.solve();
    return 0;
}
View Code

https://www.luogu.org/problem/P4287

计算串的最长双倍回文子串的长度,tips:fail指针指向当前节点所表示的回文串的最长回文后缀

#include<stdio.h>
#include<iostream>
#include<cstdio>
#include<cstring>
using namespace std;
// Capacity of the node pool and of the input buffers (500,010 entries).
const int maxn=500000+10;
// One node of the palindromic automaton; each node stands for one distinct palindrome.
struct pam_trie
{
    int ch[26]; // ch[x]: node obtained by adding letter x on both sides (0 = not created yet)
    int fail;   // node of the longest proper palindromic suffix
    int len;    // length of this node's palindrome (-1 for the odd root)
    int sum;    // incremented each time the node is the longest palindromic suffix at a position
};
// Palindromic automaton for Luogu P4287: prints the length of the longest
// "double palindrome", i.e. a palindrome whose length is a multiple of 4 and
// which has a palindromic suffix of exactly half its length.
struct pam
{
    pam_trie b[maxn];   // node pool; b[0] = even root (len 0), b[1] = odd root (len -1)
    int n;              // 1-based index of the character currently being appended
    int length;         // string length as given in the input
    int last;           // node of the longest palindromic suffix of the processed prefix
    int cnt;            // index of the most recently allocated node
    int s[maxn];        // letters mapped to 0..25; s[0] holds a sentinel
    char c[maxn];       // input text, stored starting at c[1]

    // Links the two roots together and marks the automaton as empty.
    pam():last(0),cnt(1)
    {
        b[0].len=0;
        b[0].fail=1;
        b[1].len=-1;
        b[1].fail=0;
    }

    // Reads the declared length, then the string itself into c[1..].
    void read()
    {
        scanf("%d",&length);
        scanf("%s",c+1);
    }

    // Follows fail links from x until the palindrome of the reached node can be
    // extended by s[n] on both sides. The odd root (len -1) always qualifies.
    int get_fail(int x)
    {
        const int letter=s[n];
        for(;;)
        {
            if(s[n-b[x].len-1]==letter)return x;
            x=b[x].fail;
        }
    }

    // Appends s[n], creating the node of the new longest palindromic suffix if
    // it does not exist yet, and bumps that node's counter.
    void insert()
    {
        const int letter=s[n];
        const int parent=get_fail(last);
        int &slot=b[parent].ch[letter];
        if(slot==0)
        {
            const int node=++cnt;
            b[node].len=b[parent].len+2;
            // The fail link must be resolved before the new child is attached.
            b[node].fail=b[get_fail(b[parent].fail)].ch[letter];
            slot=node;
        }
        last=slot;
        ++b[last].sum;
    }

    // Builds the automaton, then checks every node whose length is a multiple
    // of 4 for a palindromic suffix of exactly half that length.
    void solve()
    {
        s[0]=26;    // sentinel: differs from every real letter (0..25)
        for(n=1;n<=length;++n)
        {
            s[n]=c[n]-'a';
            insert();
        }
        int best=0;
        for(int node=cnt;node>=1;--node)
        {
            const int whole=b[node].len;
            // Only lengths divisible by 4 that would improve the answer matter.
            if(whole%4==0&&whole>best)
            {
                int half=node;
                while(b[half].len*2>whole)half=b[half].fail;
                if(b[half].len*2==whole)best=whole;
            }
        }
        printf("%d\n",best);
    }
}P;
// Entry point: reads the length and the string into the global automaton P,
// then prints the length of the longest double palindrome.
int main()
{
    P.read();   // scanf the declared length and the string into P
    P.solve();  // build the automaton and print the answer
    return 0;
}
View Code

ICPC 2018 南京 Mediocre String Problem,用回文自动机来求出以每个位置结尾的回文子串个数,再进行exkmp

#include<iostream>
#include<stdio.h>
#include<string.h>
#include<algorithm>
using namespace std;
// Capacity of the node pool and of every buffer in this program (1,000,010 entries).
const int maxn=1000000+10;
// One node of the palindromic automaton; each node stands for one distinct palindrome.
struct pam_trie
{
    int ch[26]; // ch[x]: node obtained by adding letter x on both sides (0 = not created yet)
    int fail;   // node of the longest proper palindromic suffix
    int len;    // length of this node's palindrome (-1 for the odd root)
    int num;    // length of the fail chain, i.e. number of palindromic suffixes including itself
};
int res[maxn];   // res[i]: number of palindromic substrings of s1 ending at 0-based index i
char s1[maxn];   // first input string; main reverses it in place
char s2[maxn];   // copy of the reversed s1, used as the text for ex_kmp
char t[maxn];    // second input string, used as the pattern

// Palindromic automaton over the global string s1; fills the global res[]
// with the number of palindromic substrings ending at each position.
struct pam
{
    pam_trie b[maxn];   // node pool; b[0] = even root (len 0), b[1] = odd root (len -1)
    int n;              // 1-based index of the character currently being appended
    int last;           // node of the longest palindromic suffix of the processed prefix
    int cnt;            // index of the most recently allocated node
    int s[maxn];        // letters mapped to 0..25; s[0] holds a sentinel
    int length;         // length of s1, set by solve()

    // Links the two roots together and marks the automaton as empty.
    pam():last(0),cnt(1)
    {
        b[0].len=0;
        b[0].fail=1;
        b[1].len=-1;
        b[1].fail=0;
    }

    // Follows fail links from x until the palindrome of the reached node can be
    // extended by s[n] on both sides. The odd root (len -1) always qualifies.
    int get_fail(int x)
    {
        const int letter=s[n];
        for(;;)
        {
            if(s[n-b[x].len-1]==letter)return x;
            x=b[x].fail;
        }
    }

    // Appends s[n] and returns the number of palindromic substrings that end
    // at position n (the `num` of the new longest palindromic suffix).
    int insert()
    {
        const int letter=s[n];
        const int parent=get_fail(last);
        int &slot=b[parent].ch[letter];
        if(slot==0)
        {
            const int node=++cnt;
            b[node].len=b[parent].len+2;
            // The fail link must be resolved before the new child is attached.
            b[node].fail=b[get_fail(b[parent].fail)].ch[letter];
            slot=node;
        }
        last=slot;
        b[last].num=b[b[last].fail].num+1;
        return b[last].num;
    }

    // Feeds s1 into the automaton, storing the per-position counts in res[].
    void solve()
    {
        length=strlen(s1);
        s[0]=26;    // sentinel: differs from every real letter (0..25)
        n=1;
        while(n<=length)
        {
            const int idx=n-1;  // 0-based position inside s1 / res
            s[n]=s1[idx]-'a';
            res[idx]=insert();
            ++n;
        }
    }
}P;

int Next[maxn];     // Next[i]: longest common prefix of the pattern and its suffix starting at i
int extend[maxn];   // extend[i]: longest common prefix of the pattern and the text suffix starting at i
// Fills the global Next[] for pattern s: Next[i] is the length of the longest
// common prefix of s and s[i..]. Next[0] is set to the full length; Next[1]
// is written even when s has fewer than two characters.
void get_next(char *s)
{
    const int n=strlen(s);

    // Position 1 is matched naively to seed the window.
    int matched=0;
    while(1+matched<n&&s[matched]==s[1+matched])++matched;
    Next[1]=matched;

    int box=1;  // start of the match window last used as the reference
    for(int i=2;i<n;++i)
    {
        const int reach=box+Next[box];  // first index past that window
        const int mirrored=Next[i-box]; // value already known for the mirrored position
        if(mirrored<reach-i)
        {
            // The mirrored match ends strictly inside the window: reuse it.
            Next[i]=mirrored;
            continue;
        }
        // Otherwise continue comparing from the window's edge.
        int j=max(0,reach-i);
        while(i+j<n&&s[j]==s[i+j])++j;
        Next[i]=j;
        box=i;
    }
    Next[0]=n;
}
void ex_kmp(char *T,char *s)
{
    int n=strlen(T),m=strlen(s),i,j,k;
    for(j=0;j<n&&j<m&&T[j]==s[j];j++);
    extend[0]=j;
    k=0;
    for(i=1;i<n;i++)
    {
        int len=k+extend[k],L=Next[i-k];
        if(L<len-i)extend[i]=L;
        else
        {
            for(j=max(0,len-i);j<m&&i+j<n&&s[j]==T[i+j];j++);
            extend[i]=j;
            k=i;
        }
    }
}


// Entry point: reads s1 and t, reverses s1, and sums extend[i] * res[i-1]
// over the reversed string, where extend[i] is the match length of t at
// position i and res[i-1] counts the palindromes ending just before it.
int main()
{
    scanf("%s",s1);
    scanf("%s",t);
    const int lens=strlen(s1);

    // Work on the reversed string; s2 is an identical copy handed to ex_kmp.
    reverse(s1,s1+lens);
    strcpy(s2,s1);

    get_next(t);
    ex_kmp(s2,t);
    P.solve();

    long long ans=0;
    int i=1;
    while(i<lens)
    {
        ans+=1ll*extend[i]*res[i-1];
        ++i;
    }
    printf("%lld\n",ans);
    return 0;
}
View Code

 

...

posted @ 2019-10-10 21:17  myrtle  阅读(135)  评论(0)  编辑  收藏  举报