BZOJ 1009 [HNOI2008]GT考试(矩阵快速幂优化DP+KMP)

题意:

求长度为n的不含长为m的指定子串的字符串的个数

1s, n<=1e9, m<=50

思路:

长见识了。。

设那个指定子串为s

f[i][j]表示长度为i的字符串(其中后j个字符与s的前j个字符一致的情况下)的方法数

若匹配到s串长度为i的后缀加一个字符num可以组成最长长度为j的后缀,设a[i][j]为num的方法数

例如,s为12312,a为

9 1 0 0 0 0
8 1 1 0 0 0
8 1 0 1 0 0
9 0 0 0 1 0
8 1 0 0 0 1

(i,j都是从0到m-1)

如a[1][2]表示从“1”到“12”可以加的字符方法数,显然加“2”才可以,所以a[1][2]=1

而a[2][0]表示从“12”到“”可以加的字符方法数:显然不能加“3”,不然s串会匹配到"123";也不能加“1”,不然s串会匹配成"1"。所以a[2][0]=8

求a矩阵的方法是kmp,感觉只可意会(我写不出来QAQ)

 

显然f[i][x]只能由f[i-1][k]转移而来,而k为多少,要看a数组了

然后状态转移方程为:$f[i][j] = f[i-1][0]*a[0][j]+f[i-1][1]*a[1][j] +\dots + f[i-1][m-1]*a[m-1][j]$

这个状态转移方程可以用矩阵快速幂来加速

答案就是$\sum f[n][i]$

代码:

#include<iostream>
#include<cstdio>
#include<algorithm>
#include<cmath>
#include<cstring>
#include<string>
#include<stack>
#include<queue>
#include<deque>
#include<set>
#include<vector>
#include<map>
#include<functional>
    
#define fst first
#define sc second
#define pb push_back
#define mem(a,b) memset(a,b,sizeof(a))
#define lson l,mid,root<<1
#define rson mid+1,r,root<<1|1
#define lc root<<1
#define rc root<<1|1
#define lowbit(x) ((x)&(-x)) 

using namespace std;

typedef double db;
typedef long double ldb;
typedef long long ll;
typedef unsigned long long ull;
typedef pair<int,int> PI;
typedef pair<ll,ll> PLL;

const db eps = 1e-6;
//const int mod = 1e9+7;
const int maxn = 2e3+100;
const int maxm = 2e6+100;
const int inf = 0x3f3f3f3f;
const db pi = acos(-1.0);

int a[60][60];
int f[60][60];
int n, m, mod;
char s[maxn];
int Next[maxn];


void mtpl(int a[60][60], int b[60][60], int s[60][60]){
    int tmp[60][60];
    for(int i = 0; i < m; i++){
        for(int j = 0; j < m; j++){
            tmp[i][j] = 0;
            for(int k = 0; k < m; k++){
                tmp[i][j]+=a[i][k]*b[k][j]%mod;
                tmp[i][j]%=mod;
            }
        }
    }
    for(int i = 0; i < m; i++){
        for(int j = 0; j < m; j++){
            s[i][j] = tmp[i][j];
        }
    }
    return;
}

void fp(int x){
    while(x){
        if(x&1)mtpl(f,a,f);
        mtpl(a,a,a);
        x>>=1;
    }
    return;
}


void kmp(){
    int fix = 0;
    for(int i = 2; i <= m; i++){
        while(fix && s[fix+1]!=s[i])fix=Next[fix];
        if(s[fix+1]==s[i])++fix;
        Next[i]=fix;
    }

    for(int i = 0; i < m; i++){
        for(char j = '0'; j <= '9'; j++){
            fix = i;
            while(fix&&s[fix+1]!=j)fix=Next[fix];
            if(j==s[fix+1])a[i][fix+1]++;
            else a[i][0]++;
        }
    }
    return;
}



int main(){
    scanf("%d %d %d", &n, &m, &mod);
    scanf("%s", s+1);
    mem(a, 0);
    kmp();
    mem(f,0);
    f[0][0]=1;
    fp(n);
    int ans = 0;
    for(int i = 0; i < m; i++){
        ans += f[0][i];
        ans%=mod;
    }
    printf("%d", ans);
    return 0;
}
/*
5
3 4 5 1 2
 */

 

posted @ 2018-11-05 20:18  wrjlinkkkkkk  阅读(152)  评论(0编辑  收藏  举报