AC自动机+dp(记忆化搜素)

https://vjudge.net/problem/UVA-11468

思路:构造出AC自动机后,把所有单词节点标记为禁止,就转化为从0节点走L步不进入任何禁止节点的概率。令dp[i][j]等于在i节点还要走j步不碰到禁忌节点的概率。

#include<iostream>
#include<iomanip>
#include<cstring>
#include<sstream>
#include<algorithm>
#include<cstdio>
#include<cmath>
#include<stdlib.h>
#include<string>
#include<queue>
#include<vector>
#include<map>
#include<stack>
//#include<bits/stdc++.h>
#define _for(i,a,b) for(int i=a;i<=b;i++)
using namespace std;
typedef long long ll;
const int mod =1e6+7;
double esp=1e-6;
int INF =0x3f3f3f3f;
const int inf = 1<<28;
const int MAXN=3e5+5;
int ch[500][70];
int tot,k,n,m,fail[500];
bool en[500],vis[500][110];
double dp[500][110],p[100];
map<char,int> mp;
char s[30][30];
void insertword(char *s)
{
    int u=0;
    int len=strlen(s);
    for(int i=0;i<len;i++)
    {
        int v=mp[s[i]];
        if(!ch[u][v])
        {
            ch[u][v]=++tot;
        }
        u=ch[u][v];
    }
    en[u]=1;
}
void getfail()
{
    queue<int> q;
    for(int i=0;i<m;i++)
    {
        if(ch[0][i])
        {
            q.push(ch[0][i]);
        }
    }
    while(!q.empty())
    {
        int u=q.front();q.pop();
        for(int i=0;i<m;i++)
        {
            if(ch[u][i])
            {
                fail[ch[u][i]]=ch[fail[u]][i];
                en[ch[u][i]]|=en[fail[ch[u][i]]];//注意更新不能走的位置
                //因为如果选这个点,而这个点的失配指针指向的位置en是1
                //说明在当前单词位置向前若干个字母会形成和另外单词一样的单词(包含关系)
                q.push(ch[u][i]);
                //printf("%d %d %d\n",ch[u][i],en[ch[u][i]],en[fail[ch[u][i]]]);
            }
            else
                {ch[u][i]=ch[fail[u]][i];}
        }
    }
}
double solve(int u,int l)
{
    if(!l)return 1.0;
    if(vis[u][l])return dp[u][l];
    vis[u][l]=1;
    double &ans=dp[u][l];
    ans=0.0;
    for(int i=0;i<m;i++)
    {
        if(!en[ch[u][i]])ans+=p[i]*solve(ch[u][i],l-1);
    }
    return ans;
}
int main()
{
    int t;
    scanf("%d",&t);
    for(int i=1;i<=t;i++)
    {
        tot=0;
        memset(en,0,sizeof(en));
        memset(vis,0,sizeof(vis));
        memset(dp,0,sizeof(dp));
        memset(fail,0,sizeof(fail));
        memset(ch,0,sizeof(ch));
        mp.clear();
        scanf("%d",&k);
        for(int j=1;j<=k;j++)
        {
            scanf("%s",&s[j]);
        }
        scanf("%d",&m);
        for(int j=0;j<m;j++)
        {
            char ss[2];
            scanf("%s",&ss);
            scanf("%lf",&p[j]);
            mp[ss[0]]=j;
        }
        for(int j=1;j<=k;j++)
            insertword(s[j]);
        getfail();
        scanf("%d",&n);
        printf("Case #%d: %0.6lf\n",i,solve(0,n));
    }
    return 0;
}