【BZOJ】4861: [Beijing2017]魔法咒语 AC自动机+DP+矩阵快速幂

【题意】给定n个原串和m个禁忌串,要求用原串集合能拼出的不含禁忌串且长度为L的串的数量。(60%)n,m<=50,L<=100。(40%)原串长度为1或2,L<=10^18。

【算法】AC自动机+DP+矩阵快速幂

【题解】其实题意的数据范围不太清晰,反正开200个点就足够了。

因为要匹配禁忌串,所以对禁忌串集合建立AC自动机,标记禁忌串结尾节点,以及下传到所有能fail到的点(这些点访问到都相当于匹配了禁忌串)。

令f[i][j]表示匹配到节点i,长度为j的串的数量,先预处理a[i][j]表示节点 i 匹配第 j 个原串到达的节点编号,那么就有:

f [ a[i][j] ] [ L+size[j] ] += f [ i ] [ L ]

以上就是60%数据的做法,对于40%的数据使用矩阵快速幂。

假设原串长度均为1,那么DP的转移如下:

$$f[i][L]=\sum_{j}f[j][L-1]\ \ ,\ \ j \rightarrow i$$

这很容易用一个长度为第一维大小(AC自动机节点数)的矩阵维护转移,第L个列向量就是f[i][L]。

如果原串长度有2,那么再记录L-1即可。

#include<cstdio>
#include<cstring>
#include<queue>
#include<algorithm>
#define ll long long
using namespace std;
const int maxn=5010,MOD=1e9+7;
int n,m,a[maxn][110],ch[maxn][27],val[maxn],size[maxn],sz=0,fail[maxn];
ll L;
char s[110][maxn],S[maxn];
queue<int>Q;
void insert(char *s){
    int n=strlen(s),u=0;
    for(int i=0;i<n;i++){
        int c=s[i]-'a';
        if(!ch[u][c])ch[u][c]=++sz;
        u=ch[u][c];
    }
    val[u]++;
}
void AC_build(){
    for(int c=0;c<26;c++)if(ch[0][c])Q.push(ch[0][c]);
    while(!Q.empty()){
        int u=Q.front();Q.pop();
        for(int c=0;c<26;c++)if(ch[u][c]){
            fail[ch[u][c]]=ch[fail[u]][c];
            Q.push(ch[u][c]);
            val[ch[u][c]]|=val[fail[ch[u][c]]];//
        }
        else ch[u][c]=ch[fail[u]][c];
    }
}
int M(int x){return x>=MOD?x-MOD:x;}
namespace Task1{
    int f[maxn][110];
    void solve(){
        f[0][0]=1;
        for(int l=0;l<L;l++){//
            for(int i=0;i<=sz;i++)if(f[i][l]){
                for(int j=1;j<=n;j++)if(~a[i][j]&&l+size[j]<=L){
                    f[a[i][j]][l+size[j]]=M(f[a[i][j]][l+size[j]]+f[i][l]);
                }
            }
        }
        int ans=0;
        for(int i=0;i<=sz;i++)if(f[i][L]&&!val[i])ans=M(ans+f[i][L]);
        printf("%d",ans);
    }
}
namespace Task2{
    const int maxn=110;
    int N,A[maxn*2][maxn*2],ANS[maxn*2][maxn*2],c[maxn*2][maxn*2];
    void mul(int a[maxn*2][maxn*2],int b[maxn*2][maxn*2]){
        for(int i=0;i<=N;i++){
            for(int j=0;j<=N;j++){
                c[i][j]=0;
                for(int k=0;k<=N;k++)c[i][j]=M(c[i][j]+1ll*a[i][k]*b[k][j]%MOD);
            }
        }
        for(int i=0;i<=N;i++)for(int j=0;j<=N;j++)b[i][j]=c[i][j];
    }
    void solve(){
        N=sz*2+1;
        for(int i=0;i<=sz;i++){
            for(int j=1;j<=n;j++)if(~a[i][j]){
                if(size[j]==1)A[a[i][j]*2][i*2]++;
                else A[a[i][j]*2][i*2+1]++;
            }
            A[i*2+1][i*2]=1;
        }
        ANS[0][0]=1;
        while(L){
            if(L&1)mul(A,ANS);
            mul(A,A);
            L>>=1;
        }
        int ans=0;
        for(int i=0;i<=sz;i++)if(!val[i])ans=M(ans+ANS[i*2][0]);
        printf("%d",ans);
    }
}
int main(){
    scanf("%d%d%lld",&n,&m,&L);
    for(int i=1;i<=n;i++)scanf("%s",s[i]);
    for(int i=1;i<=m;i++){
        scanf("%s",S);
        insert(S);
    }
    AC_build();
    memset(a,-1,sizeof(a));
    for(int k=1;k<=n;k++){
        size[k]=strlen(s[k]);
        for(int i=0;i<=sz;i++){
            int u=i;
            for(int j=0;j<size[k];j++)if(!val[u])u=ch[u][s[k][j]-'a'];else break;
            if(!val[u])a[i][k]=u;
        }
    }
    if(L<=100)Task1::solve();else
    Task2::solve();
    return 0;
}
View Code

 

posted @ 2018-03-12 13:41  ONION_CYC  阅读(464)  评论(0编辑  收藏  举报