POJ2778 DNA Sequence [AC自动机+矩阵]
又是一道调了大半天的题,最后发现竟然是自己建立trie图的地方有个小BUG,这个小BUG在字符串匹配时没什么影响,所以一直没发现出来。刚刚学习,还是理解的不够深入啊。现在这个trie图应该算是写的很简洁了,可以拿来当模版了。
题意很简单,就是问长度为n不包含若干子串的串一共有多少个。这里可以用AC自动机DP,首先对于单词的结尾节点,标记为非法节点,一旦走到了一个非法节点,就说明包含了某个单词。网上的解题报告很多人说是AC自动机DP。但这题我的做法似乎没有用到DP,只是用矩阵加速了一下罢了。。。。首先标记出非法节点,补全trie图,用一个矩阵表示从每个合法节点到其它合法节点转移的方案,可以表示为一个邻接矩阵M,求第N步从根节点到其它节点有多少方案,这个问题就是很简单的路径方案数问题,只要求M^N,然后求sum(M[1][1]...[1][N])就行了。。。。
需要注意的是,在建立trie图的过程中,如果一个节点的失败指针指向了另一个非法节点,则说明这个节点也是个非法节点。比如对于字符串{AG,CAGT},第一个字符串中的G和第二个字符串中的T很显然是非法节点。但是在第二个字符串中走到G时,实际上已经包含了AG这个字符串,它的失败指针指向AG中的G,所以这个G也是一个非法节点。
调了大半天终于AC了,稍加优化的程序跑了32ms,居然刷进了第一版。。0ms排在第一的竟然是ZFY学长,YM。。。
#include <string.h> #include <stdio.h> #include <queue> #define MAXN 110 #define MOD 100000 typedef long long LL; LL dmat[MAXN][MAXN]; struct matrix{ LL mz[MAXN][MAXN];int n; #define FOR(i) for(int i=1;i<=n;i++) //初始化矩阵,空矩阵,单位矩阵和dmat矩阵 matrix(int nn,int type):n(nn){ if(type==0)FOR(i)FOR(j)mz[i][j]=0; else if(type==1)FOR(i)FOR(j)mz[i][j]=(i==j)?1:0; else FOR(i)FOR(j)mz[i][j]=dmat[i][j]; } //重载矩阵乘法,10^5*10^5*100不会超longlong的,最后一次性模就可以了,模是很费时的 matrix operator *(const matrix& b)const{ matrix ans(n,0); FOR(i)FOR(j)if(mz[i][j]) FOR(k)ans.mz[i][k]+=mz[i][j]*b.mz[j][k]; FOR(i)FOR(j)if(ans.mz[i][j]>MOD)ans.mz[i][j]%=MOD; return ans; } //二分矩阵乘法 matrix binMat(int x){ matrix ans(n,1),tmp(n,2); for(;x;tmp=tmp*tmp,x>>=1){ if(x&1)ans=ans*tmp; } return ans; } }; int n,m; char s[12]; int next[MAXN][4],fail[MAXN],flag[MAXN],id[MAXN],ids,pos; int trans(char c){ if(c=='A')return 0; if(c=='C')return 1; if(c=='T')return 2; return 3; } int newnode(){ for(int i=0;i<4;i++)next[pos][i]=0; fail[pos]=flag[pos]=id[pos]=0; return pos++; } void insert(char *s){ int p=0,len=strlen(s); for(int i=0;i<len;i++){ int &x=next[p][trans(s[i])]; p=x?x:x=newnode(); } flag[p]=1; } int q[MAXN],front,rear; void makenext(){ q[front=rear=0]=0,rear++; while(front<rear){ int u=q[front++]; for(int i=0;i<4;i++){ int v=next[u][i]; if(flag[v])continue; if(v==0)next[u][i]=next[fail[u]][i]; else q[rear++]=v; //这个地方忘了判断v是否是0了,调了很久...省代码还是要小心啊.. if(v&&u){ fail[v]=next[fail[u]][i]; //如果指向一个非法节点,那这个节点也是一个非法节点(比如cg和acgt这样的串,第二个串中的g也是非法的) if(flag[fail[v]])flag[v]=1; } } } } int main(){ while(scanf("%d%d",&m,&n)!=EOF){ pos=ids=0;newnode(); memset(dmat,0,sizeof dmat); for(int i=0;i<m;i++){ scanf("%s",s); insert(s); } makenext(); //建立矩阵,从每个合法节点到另一个节点转移的方案数,类似于邻接矩阵 for(int u=0;u<pos;u++){ if(flag[u])continue; for(int i=0;i<4;i++){ int v=next[u][i]; if(flag[v])continue; if(id[u+1]==0)id[u+1]=++ids; if(id[v+1]==0)id[v+1]=++ids; dmat[id[u+1]][id[v+1]]++; } } matrix mt=matrix(ids,2).binMat(n); LL ans=0; for(int i=1;i<=mt.n;i++) ans+=mt.mz[1][i]; ans%=MOD; printf("%lld\n",ans); } return 0; }