POJ2778 DNA Sequence Trie+矩阵乘法

题意:给定N个有A C G T组成的字符串,求长度为L的仅由A C G T组成的字符串中有多少个是不含给定的N个字符串的题解:

首先我们把所有的模式串(给定的DNA序列)建Trie,假定我们有一个匹配串,并且在匹配过程到S[i]这个字符时匹配到了Trie上的某个节点t,那么有两种可能:

匹配失败:t->child[S[i]]为空,跳转到t->fail,因此t->fail一定不能是某个模式串的结尾;

匹配成功:跳转到t->child[S[i+1]],因此t->child[S[i+1]]一定不能是某个模式串的结尾。

另外还有一个性质:如果t->fail是某个模式串的结尾,那么t也被视作某个模式串的结尾(t->fail这个模式串为当前匹配到的位置的子串)

由于模式串最多有10个,每个模式串最长就是10,因此Trie上最多有100个节点,所以我们将Trie上的每个节点编号,构造初始矩阵a[i][j]为标号i走一步到标号j的方案数(注意Tire上的边都是有向的),能不能走由上面的规则决定。

因此问题转变为从根节点出发走n步不经过某个模式串的结尾的节点到达任意节点的方案总数,矩阵乘法随便做。

#include <map>
#include <queue>
#include <cstdio>
#include <cstring>
#include <cstdlib>
#include <iostream>
#include <algorithm>
using namespace std;
#define ll long long

const int P=100000;
const int MAXK=4;
const int MAXN=100+2;
struct Trie{
    int mark;
    Trie *child[MAXK],*fail;
    bool flag;
}*root,*mark[MAXN];
int N,M,cnt;
ll a[MAXN][MAXN],b[MAXN][MAXN],t[MAXN][MAXN],ans;
char S[MAXN];
queue<Trie *> q;
map<char,int> m;

Trie *NewNode(int k){
    Trie *x=new Trie;
    memset(x,0,sizeof(Trie));
    x->mark=k,mark[k]=x;
    return x;
}

void Insert(Trie *&x,char *S){
    Trie *p=x;
    for(int i=0;S[i];i++){
        if(!p->child[m[S[i]]]) p->child[m[S[i]]]=NewNode(++cnt);
        p=p->child[m[S[i]]];
    }
    p->flag=1;
}

void Get_Matrix(Trie *&x){
    for(int i=0;i<MAXK;i++)
        if(x->child[i]) x->child[i]->fail=root,q.push(x->child[i]);
        else x->child[i]=x;

    Trie *p,*t;
    while(!q.empty()){
        t=q.front(),q.pop();
        if(t->fail->flag) t->flag=1;
        for(int i=0;i<MAXK;i++)
            if(t->child[i]){
                t->child[i]->fail=t->fail->child[i];
                q.push(t->child[i]);
            }
            else t->child[i]=t->fail->child[i];
    }

    for(int i=1;i<=cnt;i++)
        for(int j=0;j<MAXK;j++)
            if(!mark[i]->flag && !mark[i]->child[j]->flag) a[mark[i]->mark][mark[i]->child[j]->mark]++;
}

void Matrix_Copy(ll a[MAXN][MAXN],ll b[MAXN][MAXN]){
    for(int i=1;i<=cnt;i++)
        for(int j=1;j<=cnt;j++)
            a[i][j]=b[i][j];
}

void Matrix_Mul(ll a[MAXN][MAXN],ll b[MAXN][MAXN]){
    memset(t,0,sizeof(t));
    for(int i=1;i<=cnt;i++)
        for(int j=1;j<=cnt;j++)
            for(int k=1;k<=cnt;k++)
                t[i][j]=(t[i][j]+a[i][k]*b[k][j])%P;
    Matrix_Copy(a,t);
}

void Quick_Pow(ll x[MAXN][MAXN],int y,ll ans[MAXN][MAXN]){
    if(y&1) Matrix_Copy(ans,x);
    else
        for(int i=1;i<=cnt;i++) ans[i][i]=1;

    while(y>>=1){
        Matrix_Mul(x,x);
        if(y&1) Matrix_Mul(ans,x);
    }
}

int main(){
    m['A']=0,m['G']=1,m['C']=2,m['T']=3;
    root=NewNode(++cnt);

    cin >> M >> N;
    for(int i=1;i<=M;i++){
        cin >> S;
        Insert(root,S);
    }
    Get_Matrix(root);

    Quick_Pow(a,N,b);
    for(int i=1;i<=cnt;i++) ans=(ans+b[1][i])%P;

    cout << ans << endl;

    return 0;
}
View Code

 

posted @ 2017-02-27 22:28  WDZRMPCBIT  阅读(143)  评论(0编辑  收藏  举报