BZOJ 2553 AC自动机+矩阵快速幂 (神题)

思路:
我们先对所有读进来的T建一个AC自动机
因为走到一个禁忌串就需要回到根
所以呢 搞出来所有的结束点 或一下 fail指针指向的那个点

然后我们就想转移
a[i][j]表示从i节点转移到j节点的概率 如果能够转移到 ans+=1÷alphabet
这里有一个trick
建一个size+1节点 如果回到了根 就连到size+1 a[size+1][size+1]=1
这样就成了累加和了

因为长度最大有10^9,显然直接DP会无论空间还是时间都会爆炸。。。
所以用矩阵乘法+快速幂加速转移
现在考虑怎么处理出初始的转移矩阵
先算出a[i][j]表示i一步到j的概率
用bfs就可以实现,如果j是i的儿子,那么a[i][j]+=1/字符集大小
为了方便我们新建一个节点n=cnt(总结点数)+1
每次转移root时也转移到它
那么a[i][n]就是i走一步匹配到禁忌串的概率。
要把所有步都累加出来,把a[n][n]赋为1就可以了
因为这样下一次计算时b[root][n]=….+b[root][n]*a[n][n]+….
就可以把上次的答案都累加起来了。
自乘x次后,因为贡献永远是1,所以a[root][n]就表示root走x步遇到禁忌串的期望,也就是答案。
http://blog.csdn.net/thy_asdf/article/details/47087113

//By SiriusRen
#include <cstdio>
#include <cstring>
#include <algorithm>
#define N 105
#define M 26
using namespace std;
int n,num,len,alphabet,size;
char a[N];
struct matrix{long double a[N][N];void clear(){memset(a,0,sizeof(a));}}st,ans;
matrix operator * (matrix a,matrix b){
    matrix c;c.clear();
    for(int i=0;i<=n;i++)
        for(int j=0;j<=n;j++)
            for(int k=0;k<=n;k++)
                c.a[i][j]+=a.a[i][k]*b.a[k][j];
    return c;
}
struct AC_Automata{
    int ch[N][M],end[N],q[N*M],head,tail,f[N],vis[N];
    void insert(char *s,int num){
        int u=0;
        for(int i=0;s[i];i++){
            int v=s[i]-'a';
            if(!ch[u][v])ch[u][v]=++size;
            u=ch[u][v];
        }end[u]=1;
    }
    void build(){
        f[0]=100;
        while(head<=tail){
            int r=q[head++];
            for(int i=0;i<alphabet;i++){
                int u=ch[r][i];
                if(!u)ch[r][i]=ch[f[r]][i];
                else q[++tail]=u,f[u]=ch[f[r]][i];
            }
            end[r]|=end[f[r]];
        }
        head=tail=q[0]=0,vis[0]=1;
        long double base=1.0/alphabet;
        while(head<=tail){
            int r=q[head++];
            for(int i=0;i<alphabet;i++){
                if(!vis[ch[r][i]])vis[ch[r][i]]=1,q[++tail]=ch[r][i];
                if(end[ch[r][i]])st.a[r][n]+=base,st.a[r][0]+=base;
                else st.a[r][ch[r][i]]+=base;
            }
        }
    }
}ac;
void pow(){for(;len;len>>=1,st=st*st)if(len&1)ans=ans*st;}
int main(){
    scanf("%d%d%d",&num,&len,&alphabet);
    for(int i=1;i<=num;i++)scanf("%s",a),ac.insert(a,i);
    n=size+1;ac.build(),st.a[n][n]=1;
    for(int i=0;i<=n;i++)ans.a[i][i]=1;
    pow();
    printf("%Lf\n",ans.a[0][n]);
}

这里写图片描述

posted @ 2016-12-09 16:14  SiriusRen  阅读(177)  评论(0编辑  收藏  举报