HDU 2243 考研路茫茫——单词情结(AC自动机 + 矩阵快速幂)题解

题意:找出所有长度不大于L的,包含至少一个模式串的主串的个数。

思路:和2778类似,但是这里求1~L所有长度的种数。所以我们只要求出来不包含的所有个数就行。

假设AC自动机上所有节点的邻接矩阵为A,那么答案为$\sum_{i=1}^n 26^i - \sum_{i=1}^n A^i$。

因为L有点大,那么我们可以直接用矩阵快速幂来求:

令$S_n = \sum_{i=1}^n 26^i ,T_n = \sum_{i=1}^n A^i$

$$ \left[ \begin{matrix} T_{n-1} & A \\ 0 & 0 \end{matrix} \right] * \left[ \begin{matrix} A & 0 \\ E & E \end{matrix} \right] = \left[ \begin{matrix} T_{n} & A \\ 0 & 0 \end{matrix} \right]$$

$$ \left[ \begin{matrix} S_{n-1} & 26 \\ 0 & 0 \end{matrix} \right] * \left[ \begin{matrix} 26 & 0 \\ 1 & 1 \end{matrix} \right] = \left[ \begin{matrix} S_{n} & 26 \\ 0 & 0 \end{matrix} \right]$$

求$T_n$的时候直接开一个大矩阵求就行了

 

代码:

#include<cmath>
#include<set>
#include<map>
#include<queue>
#include<cstdio>
#include<vector>
#include<cstring>
#include <iostream>
#include<algorithm>
using namespace std;
typedef long long ll;
typedef unsigned long long ull;
const int maxn = 30 + 5;
const int M = 50 + 5;
const ull seed = 131;
const double INF = 1e20;
const int MOD = 100000;
int n;
ll L;
struct Mat{
    ull s[maxn * 2][maxn * 2];
    void init(){
        for(int i = 0; i < maxn * 2; i++)
            for(int j = 0; j < maxn * 2; j++)
                s[i][j] = 0;
    }
};
Mat mul(Mat &a, Mat &b, int tn){
    Mat t;
    t.init();
    for(int i = 0; i < tn; i++){
        for(int j = 0; j < tn; j++){
            for(int k = 0; k < tn; k++){
                t.s[i][j] = t.s[i][j] + a.s[i][k] * b.s[k][j];
            }
        }
    }
    return t;
}
Mat ppow(Mat a, ll b, int tn){
    Mat ret;
    ret.init();
    for(int i = 0; i < maxn * 2; i++) ret.s[i][i] = 1;
    while(b){
        if(b & 1) ret = mul(ret, a, tn);
        a = mul(a, a, tn);
        b >>= 1;
    }
    return ret;
}
struct Aho{
    struct state{
        int next[26];
        int fail, cnt;
    }node[maxn];
    int size;
    queue<int> q;

    void init(){
        size = 0;
        newtrie();
        while(!q.empty()) q.pop();
    }

    int newtrie(){
        memset(node[size].next, 0, sizeof(node[size].next));
        node[size].cnt = node[size].fail = 0;
        return size++;
    }

    void insert(char *s){
        int len = strlen(s);
        int now = 0;
        for(int i = 0; i < len; i++){
            int c = s[i] - 'a';
            if(node[now].next[c] == 0){
                node[now].next[c] = newtrie();
            }
            now = node[now].next[c];
        }
        node[now].cnt = 1;
    }

    void build(){
        node[0].fail = -1;
        q.push(0);

        while(!q.empty()){
            int u = q.front();
            q.pop();
            if(node[node[u].fail].cnt && u) node[u].cnt = 1;   //都不能取
            for(int i = 0; i < 26; i++){
                if(!node[u].next[i]){
                    if(u == 0)
                        node[u].next[i] = 0;
                    else
                        node[u].next[i] = node[node[u].fail].next[i];
                }
                else{
                    if(u == 0) node[node[u].next[i]].fail = 0;
                    else{
                        int v = node[u].fail;
                        while(v != -1){
                            if(node[v].next[i]){
                                node[node[u].next[i]].fail = node[v].next[i];
                                break;
                            }
                            v = node[v].fail;
                        }
                        if(v == -1) node[node[u].next[i]].fail = 0;
                    }
                    q.push(node[u].next[i]);
                }
            }
        }
    }

    ull query(){
        Mat A;
        A.init();
        for(int i = 0; i < size; i++){
            for(int j = 0; j < 26; j++){
                if(node[node[i].next[j]].cnt == 0){
                    A.s[i][node[i].next[j]]++;
                }
            }
        }
        Mat a, b;
        a.init(), b.init();
        for(int i = 0; i < size; i++){
            for(int j = 0; j < size; j++){
                a.s[i][j] = a.s[i][j + size] = b.s[i][j] = A.s[i][j];
            }
        }
        for(int i = 0; i < size; i++){
            b.s[i + size][i] = b.s[i + size][i + size] = 1;
        }

        b = ppow(b, L - 1, 2 * size);
        a = mul(a, b, 2 * size);
        ull ret = 0;
        for(int i = 0; i < size; i++){
            if(node[i].cnt == 0) ret += a.s[0][i];
        }
        return ret;
    }

}ac;
char s[20];
int main(){
    while(~scanf("%d%lld", &n, &L)){
        ull ans = 0;
        Mat a, b;
        a.init(), b.init();
        a.s[0][0] = a.s[0][1] = 26;
        b.s[0][0] = 26, b.s[1][0] = 1, b.s[1][1] = 1;
        b = ppow(b, L - 1, 4);
        a = mul(a, b, 4);
        ans = a.s[0][0];

        ac.init();
        while(n--){
            scanf("%s", s);
            ac.insert(s);
        }
        ac.build();
        ull ret = ac.query();
        cout << ans - ret << endl;
    }
    return 0;
}

 

posted @ 2019-07-13 16:17  KirinSB  阅读(192)  评论(0编辑  收藏  举报