poj2778 求构造长度为n的字符串不包含给定的m个字符串的个数(矩阵乘法+ac自动机)
题:http://poj.org/problem?id=2778
题意:给定m个模式串,问长度为n的字符串不包含这些模式串的有几种可能
分析:因为n很大,所以考虑矩阵ksm来解决,构造一个矩阵res[i][j]表示从i到j有多少种方案数,我们先考虑只走1步后的res数组的构造,i节点能走到j节点当且仅当i节点和j节点都是安全的点,这个安全的点就是用m个模式串构成的trie树上的end[],显然根结点是安全结点。 一个非根结点是危险结点的充要条件是: 它的路径字符串本身就是一个不良单词 ,或 它的路径字符串的后缀对应的结点(即fail[i])是危险结点。预处理完ac自动机后,就可以处理res数组,这个res数组就相当于在学离散数学时的矩阵;剩下的n步就交给ksm这个res数组即可,答案就是sum(res[0][i])
#include<cstdio> #include<algorithm> #include<iostream> #include<cstring> #include<queue> #include<cmath> using namespace std; typedef long long ll; const int mod=1e5; const int maxn=4;///只有4个字母 const int N=1e3+3; struct ac{ int trie[N][maxn],fail[N]; int tot,root; bool end[N]; ll res[N][N],ans[N][N],tmp[N][N]; int newnode(){ for(int i=0;i<maxn;i++){ trie[tot][i]=-1; } end[tot++]=0; return tot-1; } void init(){ tot=0; root=newnode(); memset(res,0,sizeof(res)); memset(ans,0,sizeof(ans)); memset(end,false,sizeof(end)); } int getid(char c){ if(c=='A') return 0; if(c=='C') return 1; if(c=='T') return 2; if(c=='G') return 3; } void insert(char *buf,int id){ int now=root,len=strlen(buf); for(int i=0;i<len;i++){ int x=getid(buf[i]); if(trie[now][x]==-1) trie[now][x]=newnode(); now=trie[now][x]; } end[now]=true;//它的路径字符串本身就是一个不良单词 } void getfail(){ queue<int>que; while(!que.empty()) que.pop(); fail[root]=root; for(int i=0;i<maxn;i++) if(trie[root][i]==-1) trie[root][i]=root; else{ fail[trie[root][i]]=root; que.push(trie[root][i]); } while(!que.empty()){ int now=que.front(); que.pop(); if(end[fail[now]])//它的路径字符串的后缀对应的结点(即fail[i])是危险结点 end[now]=true; for(int i=0;i<maxn;i++){ if(trie[now][i]!=-1){ fail[trie[now][i]]=trie[fail[now]][i]; que.push(trie[now][i]); } else trie[now][i]=trie[fail[now]][i]; } } } void path(){ for(int i=0;i<tot;i++){ for(int j=0;j<maxn;j++) if(!end[i]&&!end[trie[i][j]]){ // cout<<i<<"!!"<<j<<endl; res[i][trie[i][j]]++; } } } void mul(ll a[][N],ll b[][N]){ for(int i=0;i<tot;i++) for(int j=0;j<tot;j++){ tmp[i][j]=0; for(int k=0;k<tot;k++) tmp[i][j]=(tmp[i][j]+a[i][k]*b[k][j])%mod; } for(int i=0;i<tot;i++) for(int j=0;j<tot;j++) a[i][j]=tmp[i][j]; } ll solve(int n){ for(int i=0;i<tot;i++) ans[i][i]=1; while(n){ if(n&1){ mul(ans,res); } n>>=1; mul(res,res); } ll ANS=0; for(int i=0;i<tot;i++) ANS=(ANS+ans[0][i])%mod; return ANS; } }AC; char s[110]; int main(){ int m,n; scanf("%d%d",&m,&n); AC.init(); for(int i=1;i<=m;i++){ scanf("%s",s); AC.insert(s,i); } AC.getfail(); AC.path(); printf("%lld\n",AC.solve(n)); return 0; }