P3193 [HNOI2008]GT考试(KMP+矩阵乘法加速dp)

P3193 [HNOI2008]GT考试

思路:

\(dp(i,j)\)\(N\)位数从高到低第\(i\)位时,不吉利数字在第\(j\)位时的情况总数,那么转移方程就为:

\[dp(i,j)=dp(i+1,k)*a(j,k) \]

这里\(a(j,k)\)就是从第\(j\)位到第\(k\)位的情况总数。那么根据这个转移方程我们就可以直接求解了。但是题目中\(N\)的范围过大,直接枚举可能要爆炸,我们这样考虑,将dp方程稍微变化一下:

\[dp(i,j)=\sum_{k=1}^m dp(i-1,k)*a(k,j) \]

那么这里的\(a(k,j)\)就相当于矩阵的一列,我们将\(i-1\)的状态与每一列相乘就可以得到\(i\)的所有状态。那么我们矩阵加速一下就好了。
注意在构造矩阵的时候,不要考虑最后一位就行了,这样就匹配成功了。

代码如下:

#include <bits/stdc++.h>
using namespace std;
typedef long long ll;
const int N = 25, MAX = 10;
int n, m, K;
char s[N];
int nxt[N][MAX];
struct Matrix{
    int n;
    ll A[N][N];
    Matrix() {
        memset(A, 0, sizeof(A)) ;
    }
    void Print() {
        for(int i = 0; i <= n; i++) {
            for(int j = 0; j <= n; j++) {
                cout << A[i][j] << ' ' ;
            }
            cout << '\n' ;
        }
    }
}trans, dp;
Matrix operator * (const Matrix &a, const Matrix &b) {
    Matrix ans;
    ans.n = a.n;
    for(int i = 0; i <= ans.n; i++)
        for(int j = 0; j <= ans.n; j++)
            for(int k = 0; k <= ans.n; k++)
                ans.A[i][j] = (ans.A[i][j] + a.A[i][k] * b.A[k][j]) % K;
    return ans;
}
Matrix qp(Matrix a, ll b) {
    Matrix ans; ans.n = a.n;
    for(int i = 0; i <= ans.n; i++) ans.A[i][i] = 1;
    while(b) {
        if(b & 1) ans = ans * a;
        a = a * a;
        b >>= 1;
    }
    return ans ;
}
void Get_nxt(char *s, int nxt[][MAX]) {
    int L = strlen(s + 1) ;
    for(int k = 0; k < L; k++) {
        int l = k + 1;
        for(int p = 0; p < MAX; p++) {
            for(int i = min(l, L); i >= 0; i--) {
                bool flag = true;
                for(int j = 1; j < i; j++)
                    if(s[j] != s[l - i + j]) flag = false ;
                if(s[i] - '0' != p) flag = false ;
                if(flag) {
                    nxt[k][p] = i;
                    break ;
                }
            }
        }
    }
}
int main() {
    cin >> n >> m >> K;
    scanf("%s",s + 1) ;
    Get_nxt(s, nxt) ;
    int L = strlen(s + 1) ;
    trans.n = dp.n = m;
    for(int i = 0; i < m; i++) {
        for(int p = 0; p < MAX; p++) {
            int j = nxt[i][p] ;
            if(j != m) trans.A[i][j]++;
        }
    }
    dp.A[0][0] = 1 ;
    trans = qp(trans, n) ;
    ll ans = 0;
    dp = dp * trans;
    for(int i = 0; i < m; i++)
        ans = (ans + dp.A[0][i]) % K ;
    cout << ans ;
    return 0;
}
posted @ 2019-05-16 20:06  heyuhhh  阅读(251)  评论(0编辑  收藏  举报