AC自动机
对于多串匹配一种能够理论上时间复杂度为O(n+m)的多串匹配方式,但是时间复杂度并不稳定。
原理:在trie树上建立类似KMP的next指针的东西,也就是AC自动机的fail指针,在每次匹配的时候,不停的跳fail指针直到根节点。这是最裸的实现,但是许多情况下,这种最朴素的实现方式过不去...因为这样它的时间复杂度很不稳定,可能被卡到O(nm)这种模式下,我们就需要将AC自动机的BFS序拿出来,单独进行操作,甚至将fail指针建成fail树实现,还有什么打上差分标记之类的实现方式降低时间复杂度。
例题:
BZOJ3940: [Usaco2015 Feb]Censoring
分析:
AC自动机裸题,同年同月银组题是KMP+栈实现,而这道题是AC自动机+栈实现,因为题目满足模式串互不包含,所以直接贪心+栈模拟维护一下即可。
附上代码:
#include <cstdio> #include <algorithm> #include <cmath> #include <cstdlib> #include <cstring> #include <queue> #include <iostream> using namespace std; #define N 1000005 struct Aho { int ch[N][26],last[N],fail[N],cnt,rot; int new_node(){memset(ch[cnt],-1,sizeof(ch[cnt]));last[cnt++]=0;return cnt-1;} void init(){cnt=0;rot=new_node();} void insert(char *s,int x) { int len=strlen(s),rt=0; for(int i=0;i<len;i++) { if(ch[rt][s[i]-'a']==-1)ch[rt][s[i]-'a']=new_node(); rt=ch[rt][s[i]-'a']; } last[rt]=x; } void get_fail() { queue <int>q;fail[0]=0; for(int i=0;i<26;i++) { if(ch[0][i]==-1)ch[0][i]=0; else fail[ch[0][i]]=0,q.push(ch[0][i]); } while(!q.empty()) { int x=q.front();q.pop(); for(int i=0;i<26;i++) { if(ch[x][i]==-1)ch[x][i]=ch[fail[x]][i]; else fail[ch[x][i]]=ch[fail[x]][i],q.push(ch[x][i]); } } } }ac; int n,m,len[N],sta[N],top,pos[N]; char sub[N],str[N]; int main() { scanf("%s%d",str+1,&n);m=strlen(str+1);ac.init(); for(int i=1;i<=n;i++)scanf("%s",sub),ac.insert(sub,i),len[i]=strlen(sub); ac.get_fail();int rt=0; for(int i=1;i<=m;i++) { rt=ac.ch[rt][str[i]-'a'];pos[++top]=rt,sta[top]=str[i]; if(ac.last[rt])top-=len[ac.last[rt]],rt=pos[top]; } for(int i=1;i<=top;i++)printf("%c",sta[i]);puts(""); return 0; }
BZOJ3172: [Tjoi2013]单词
分析:
这道题,最裸的AC自动机实现方式过不了...不用测试了,我试过了...
将AC自动机建起来,在每个节点打一个标记,每次插入的时候遍历到这个节点,这个节点的出现次数就++,之后每次讲节点的出现次数传递给fail节点,最后统计每一个串的终止节点出现次数即可。
附上代码:
#include <cstdio> #include <algorithm> #include <queue> #include <cstring> #include <cstdlib> #include <cmath> #include <iostream> #include <set> using namespace std; #define N 1000005 int ans[N]; struct Aho { int ch[N][26],pos[N],fail[N],vis[N],cnt,rot,que[N],fa[N],l,r; int new_node(){memset(ch[cnt],-1,sizeof(ch[cnt]));vis[cnt++]=0;return cnt-1;} void init(){l=r=cnt=0,rot=new_node();} void insert(char *s,int x) { int len=strlen(s),rt=rot; for(int i=0;i<len;i++) { int t=s[i]-'a'; if(ch[rt][t]==-1)ch[rt][t]=new_node(); rt=ch[rt][t];vis[rt]++; } pos[x]=rt; } void get_fail() { for(int i=0;i<26;i++) { if(ch[0][i]==-1)ch[0][i]=0; else que[r++]=ch[0][i],fail[ch[0][i]]=0; } while(l<r) { int x=que[l++]; for(int i=0;i<26;i++) { if(ch[x][i]==-1)ch[x][i]=ch[fail[x]][i]; else fail[ch[x][i]]=ch[fail[x]][i],que[r++]=ch[x][i]; } } } int match() { for(int i=cnt;i>=0;i--) { int x=que[i]; vis[fail[x]]+=vis[x]; } } }ac; char s[N]; int main() { int n; scanf("%d",&n);ac.init(); for(int i=1;i<=n;i++) { scanf("%s",s); ac.insert(s,i); } ac.get_fail();ac.match(); for(int i=1;i<=n;i++) { //printf("%d\n",ac.pos[i]); printf("%d\n",ac.vis[ac.pos[i]]); } }
BZOJ1030: [JSOI2007]文本生成器
分析:
这道题看起来和GT考试很像,用全部的生成数量去掉完全不可读的数量,就是答案。全部的数量是26^m,而完全不可读的是类似GT考试的求法,建立AC自动机,之后类似遍历AC自动机的方式进行DP,顺便注意一点,即:如果跳fail指针能跳到某一个字符串的终止节点,那么这个节点就不能作为答案出现,即:我们统计所有的不存在一个终止节点作为祖先的节点。
附上代码:
#include <cstdio> #include <cmath> #include <algorithm> #include <iostream> #include <queue> #include <cstdlib> #include <cstring> using namespace std; #define N 200005 #define mod 10007 int f[120][N],n,m;char s[N]; struct Aho { int ch[N][26],fail[N],last[N],cnt,q[N],l,r,rot; int new_node(){memset(ch[cnt],-1,sizeof(ch[cnt]));last[cnt++]=0;return cnt-1;} void init(){cnt=0;rot=new_node();} void insert(char *s) { int len=strlen(s),rt=rot; for(int i=0;i<len;i++) { int t=s[i]-'A'; if(ch[rt][t]==-1)ch[rt][t]=new_node(); rt=ch[rt][t]; } last[rt]=1; } void get_fail() { l=r=0; for(int i=0;i<26;i++) { if(ch[0][i]==-1)ch[0][i]=rot; else q[r++]=ch[0][i],fail[ch[0][i]]=rot; last[0]|=last[fail[0]]; } while(l<r) { int x=q[l++]; for(int i=0;i<26;i++) { if(ch[x][i]==-1)ch[x][i]=ch[fail[x]][i]; else q[r++]=ch[x][i],fail[ch[x][i]]=ch[fail[x]][i]; } last[x]|=last[fail[x]]; } } int match() { f[0][0]=1; for(int i=1;i<=m;i++) for(int j=0;j<cnt;j++) if(!last[j]&&f[i-1][j]) for(int k=0;k<26;k++) f[i][ch[j][k]]=(f[i][ch[j][k]]+f[i-1][j])%mod; int num=1,ans=0; for(int i=1;i<=m;i++)num=num*26%mod; for(int i=0;i<cnt;i++) { if(!last[i])ans=(ans+f[m][i])%mod; } return (num-ans+mod)%mod; } }ac; int main() { scanf("%d%d",&n,&m);ac.init(); for(int i=1;i<=n;i++) { scanf("%s",s);ac.insert(s); } ac.get_fail(); printf("%d\n",ac.match()); return 0; }
BZOJ2553: [BeiJing2011]禁忌
分析:
看到len<=10^9就知道和矩阵乘法有关。建立矩阵,如果一个节点是终止节点(或者它的祖先存在终止节点),那么将矩阵i,0变成抽到对应字符的概率,和i,cnt也变成对应概率,如果不是终止节点,那么将i和对应子节点的矩阵改成概率即可。之后矩阵乘法实现一下即可。
附上代码:
#include <cstdio> #include <cmath> #include <algorithm> #include <iostream> #include <queue> #include <cstdlib> #include <cstring> using namespace std; #define N 2005 #define mod 10007 int n,m,alpha,cnt;char s[N]; struct node { long double a[100][100]; friend node operator*(const node &a,const node &b) { node c;memset(c.a,0,sizeof(c.a)); for(int i=0;i<=cnt;i++) { for(int j=0;j<=cnt;j++) { for(int k=0;k<=cnt;k++) { c.a[i][j]=(c.a[i][j]+a.a[i][k]*b.a[k][j]); } } } return c; } void print() { for(int i=0;i<=cnt;i++) { for(int j=0;j<=cnt;j++) { printf("%.5lf ",a[i][j]); } puts(""); } } }ret,map; double q_pow(int n) { for(int i=0;i<=cnt;i++)ret.a[i][i]=1; while(n) { if(n&1)ret=ret*map; map=map*map;n=n>>1; } return ret.a[0][cnt]; } struct Aho { int ch[N][26],fail[N],last[N],q[N],l,r,rot; int new_node(){memset(ch[cnt],-1,sizeof(ch[cnt]));last[cnt++]=0;return cnt-1;} void init(){cnt=0;rot=new_node();} void insert(char *s) { int len=strlen(s),rt=rot; for(int i=0;i<len;i++) { int t=s[i]-'a'; if(ch[rt][t]==-1)ch[rt][t]=new_node(); rt=ch[rt][t]; } //printf("%d",rt); last[rt]=1; } void get_fail() { l=r=0; for(int i=0;i<alpha;i++) { if(ch[0][i]==-1)ch[0][i]=rot; else q[r++]=ch[0][i],fail[ch[0][i]]=rot; last[0]|=last[fail[0]]; } while(l<r) { int x=q[l++]; for(int i=0;i<alpha;i++) { if(ch[x][i]==-1)ch[x][i]=ch[fail[x]][i]; else q[r++]=ch[x][i],fail[ch[x][i]]=ch[fail[x]][i]; } last[x]|=last[fail[x]]; } } void match() { long double addv=1.0/(1.0*alpha); for(int i=0;i<cnt;i++) { for(int j=0;j<alpha;j++) { if(last[ch[i][j]])map.a[i][0]+=addv,map.a[i][cnt]+=addv; else map.a[i][ch[i][j]]+=addv; } } map.a[cnt][cnt]=1; } }ac; int main() { scanf("%d%d%d",&n,&m,&alpha);ac.init(); for(int i=1;i<=n;i++) { scanf("%s",s);ac.insert(s); } ac.get_fail();ac.match();//map.print(); printf("%.10f\n",q_pow(m)); return 0; }
BZOJ2938: [Poi2000]病毒
分析:
如果在trie树上存在一个环,那么就一定可以出现无穷的情况。判一下环是否存在即可。
附上代码:
#include <cstdio> #include <cmath> #include <algorithm> #include <iostream> #include <queue> #include <cstdlib> #include <cstring> using namespace std; #define N 30005 struct Aho { int ch[N][2],fail[N],last[N],cnt,rot,vis[N],inq[N]; int new_node(){ch[cnt][1]=ch[cnt][0]=-1;last[cnt++]=0;return cnt-1;} void init(){cnt=0;rot=new_node();} void insert(char *s) { int len=strlen(s),rt=rot; for(int i=0;i<len;i++) { int t=s[i]-'0'; if(ch[rt][t]==-1)ch[rt][t]=new_node(); rt=ch[rt][t]; } last[rt]=1; } void get_fail() { queue <int>q; for(int i=0;i<2;i++) { if(ch[0][i]==-1)ch[0][i]=0; else fail[ch[0][i]]=0,q.push(ch[0][i]); } while(!q.empty()) { int x=q.front();q.pop(); for(int i=0;i<2;i++) { if(ch[x][i]==-1)ch[x][i]=ch[fail[x]][i]; else fail[ch[x][i]]=ch[fail[x]][i],q.push(ch[x][i]); } last[x]|=last[fail[x]]; } } int match(int x) { inq[x]=vis[x]=1; for(int i=0;i<2;i++) { int t=ch[x][i]; if(inq[t]||((!last[t])&&(!vis[t])&&match(t)))return 1; } inq[x]=0; return 0; } }ac; char s[N]; int main() { int n;scanf("%d",&n);ac.init(); for(int i=1;i<=n;i++) { scanf("%s",s);ac.insert(s); } ac.get_fail(); if(ac.match(ac.rot))puts("TAK"); else puts("NIE"); return 0; }
BZOJ2434: [Noi2011]阿狸的打字机
分析:
先将给你的串建立成AC自动机,之后单独拎出fail树,求出每个节点在fail树上的入栈出栈序维护出来,之后再遍历一遍所有节点,将每个节点对应的询问求出即可。而正确性是因为fail树的每个节点的父节点的串都是这个节点的串的后缀。
附上代码:
#include <cstdio> #include <cmath> #include <algorithm> #include <iostream> #include <queue> #include <cstdlib> #include <cstring> using namespace std; #define N 100005 struct node { int to,next,val; }ask[N],e[N]; int head[N],head_ask[N],cnt2,cnt1,in1[N],tims,out1[N],sum[N<<1],Q,flg[N],ans[N];char s[N]; void add(int x,int y){e[cnt2].to=y;e[cnt2].next=head[x];head[x]=cnt2++;} void add_ask(int x,int y,int z){ask[cnt1].to=y;ask[cnt1].next=head_ask[x];ask[cnt1].val=z;head_ask[x]=cnt1++;} void fix(int x,int c){for(;x<=tims;x+=x&-x)sum[x]+=c;} int find(int x){int ret=0;for(;x;x-=x&-x)ret+=sum[x];return ret;} struct Aho { int ch[N][26],fail[N],q[N],last[N],rot,cnt,fa[N]; int new_node(){memset(ch[cnt],-1,sizeof(ch[cnt]));last[cnt++]=0;return cnt-1;} void init(){cnt=0,rot=new_node();} void insert(char *s) { int len=strlen(s),rt=rot,tot=0; for(int i=0;i<len;i++) { if(s[i]=='B')rt=fa[rt]; else if(s[i]=='P')last[rt]++,flg[++tot]=rt; else { if(ch[rt][s[i]-'a']==-1)ch[rt][s[i]-'a']=new_node(),fa[cnt-1]=rt; rt=ch[rt][s[i]-'a']; } } } void get_fail() { int l=0,r=0; for(int i=0;i<26;i++) { if(ch[0][i]==-1)ch[0][i]=0; else fail[ch[0][i]]=0,q[r++]=ch[0][i]; } while(l<r) { int x=q[l++]; for(int i=0;i<26;i++) { if(ch[x][i]==-1)ch[x][i]=ch[fail[x]][i]; else fail[ch[x][i]]=ch[fail[x]][i],q[r++]=ch[x][i]; } } for(int i=1;i<cnt;i++)add(fail[i],i);//printf("%d %d\n",fail[i],i); } void solve(char *s) { for(int i=0,len=strlen(s),rt=0;i<len;i++) { if(s[i]=='B')fix(out1[rt],-1),rt=fa[rt]; else if(s[i]=='P') { for(int j=head_ask[rt];j!=-1;j=ask[j].next) { int to1=ask[j].to,v=ask[j].val; ans[v]=find(out1[to1])-find(in1[to1]-1); } }else { rt=ch[rt][s[i]-'a']; fix(in1[rt],1); } } } }ac; void dfs(int x) { in1[x]=++tims; for(int i=head[x];i!=-1;i=e[i].next) { dfs(e[i].to); } out1[x]=++tims; } int main() { scanf("%s%d",s,&Q);memset(head,-1,sizeof(head));memset(head_ask,-1,sizeof(head_ask)); ac.init();ac.insert(s);ac.get_fail();dfs(0); for(int i=1;i<=Q;i++){int x,y;scanf("%d%d",&x,&y);add_ask(flg[y],flg[x],i);}ac.solve(s); for(int i=1;i<=Q;i++)printf("%d\n",ans[i]);return 0;