计蒜客 疑似病毒 (AC自动机 + 可达矩阵)

/*************************************************************************
	> File Name: ac_machine.cpp
	> Author: 
	> Mail: 
	> Created Time: 2017年11月25日 星期六 11时01分11秒
 ************************************************************************/

#include <cstdio>
#include <cstring>
#include <cstdlib>
#include <iostream>
#include <algorithm>
using namespace std;

// Capacity unit for the BFS work queues: build_automaton sizes its queue
// with MAX_N, BFS sizes its arrays with MAX_N * 3.
const int MAX_N = 10000;
// Side length of the fixed-size storage inside Matrix.
const int MAT_X = 200;
// Number of child slots per trie node (transStr produces the symbols '0'..'3').
const int SIZE = 4;
// Character subtracted from a symbol in insert() to obtain its child slot.
const int BASE = '0'; 
// Number of layers per node (length of Trie::cal_val and Trie::index).
const int LAYER = 3;
// Modulus applied in matrix multiplication and to the printed answer.
const int MOD = 10007;
int node_cnt = 1;  // Number of trie nodes, used to assign indices: counted in insert, consumed in build_automaton
// Active matrix dimension; main sets it to node_cnt * 3 for each test case.
int mat_size;

// Matrix 结构声明
struct Matrix {
    Matrix() {
        memset(this->m, 0, sizeof(this->m));
    }
    int m[MAT_X][MAT_X];
    void show() {
        for (int i = 0 ; i < mat_size ; ++i) {
            for (int j = 0 ; j < mat_size ; ++j) {
                printf("%4d ", this->m[i][j]);
            }
            printf("\n");
        }
    }
};
// Identity matrix; its diagonal is filled in by init_unint_mat() and it is
// the starting value of quick_mat_pow().
Matrix unit_mat;
Matrix CGMatrix; // Reachability matrix: BFS() increments one entry per transition it finds

// Turn unit_mat into an identity matrix by writing 1 on its full
// MAT_X-long diagonal (off-diagonal entries are left untouched).
void init_unint_mat() {
    int diag = 0;
    while (diag < MAT_X) {
        unit_mat.m[diag][diag] = 1;
        ++diag;
    }
}

// Multiply the active mat_size x mat_size corners of a and b, reducing
// modulo MOD after every term. Entries outside that corner stay zero.
Matrix operator* (const Matrix &a, const Matrix &b) {
    Matrix product;
    for (int row = 0; row < mat_size; ++row) {
        const int *lhs_row = a.m[row];
        int *out_row = product.m[row];
        for (int col = 0; col < mat_size; ++col) {
            int acc = 0;
            for (int mid = 0; mid < mat_size; ++mid) {
                acc = (acc + lhs_row[mid] * b.m[mid][col]) % MOD;
            }
            // acc is already reduced, so it can be stored as is.
            out_row[col] = acc;
        }
    }
    return product;
}

// Fast exponentiation: returns matrix a raised to the power x by binary
// exponentiation, starting from unit_mat (which must already hold the
// identity). x is expected to be non-negative.
Matrix quick_mat_pow(Matrix a, int x) {
    Matrix result = unit_mat;
    for (; x != 0; x >>= 1) {
        // Fold in the current power of a when the lowest bit of x is set.
        if ((x & 1) != 0) {
            result = result * a;
        }
        a = a * a;
    }
    return result;
}

// Trie node declaration.
//   base_val : base value; insert() sets it to 1 at the end of a pattern and
//              build_automaton() adds the fail node's base_val to it
//   cal_val  : computed value, one slot per layer, written by BFS()
//   index    : matrix index of this node in each layer, assigned by
//              build_automaton()
//   fail     : failure link (NULL for the root)
//   childs   : heap array of SIZE child pointers
struct Trie {
    int base_val;
    int cal_val[LAYER];
    int index[LAYER];
    struct Trie *fail;
    struct Trie **childs;
};
typedef struct Trie Node;
typedef struct Trie *Tree;

// Allocate a trie node with no children, no failure link, and all of
// base_val / index / cal_val set to zero.
//
// Returns a pointer the caller owns; release the whole tree with clear().
// No caller checks for NULL, so an allocation failure is reported on stderr
// and terminates the program instead of being dereferenced.
Trie* new_node() {
    Trie *p = static_cast<Trie *>(malloc(sizeof(Trie)));
    if (p == NULL) {
        fprintf(stderr, "new_node: out of memory\n");
        exit(EXIT_FAILURE);
    }
    // calloc hands back the child table already zero-filled.
    p->childs = static_cast<Trie **>(calloc(SIZE, sizeof(Trie *)));
    if (p->childs == NULL) {
        free(p);
        fprintf(stderr, "new_node: out of memory\n");
        exit(EXIT_FAILURE);
    }
    p->fail = NULL;
    p->base_val = 0;
    memset(p->index, 0, sizeof(p->index));
    memset(p->cal_val, 0, sizeof(p->cal_val));
    return p;
}

// Recursively free the subtree rooted at node: children first, then the
// child table, then the node itself. A NULL node is a no-op.
void clear(Trie *node) {
    if (node != NULL) {
        Trie **kids = node->childs;
        for (int slot = 0; slot < SIZE; ++slot) {
            clear(kids[slot]);
        }
        free(kids);
        free(node);
    }
}

// Insert the pattern str into the trie rooted at node. Each character is
// mapped to a child slot by subtracting BASE; missing nodes are created and
// counted in node_cnt. The node reached at the end gets base_val = 1.
void insert(Trie *node, char *str) {
    Trie *cur = node;
    for (const char *ch = str; *ch != '\0'; ++ch) {
        Trie *&slot = cur->childs[*ch - BASE];
        if (slot == NULL) {
            slot = new_node();
            // Keep the node count up to date.
            ++node_cnt;
        }
        cur = slot;
    }
    cur->base_val = 1;
}

// Build the automaton: walk the trie breadth-first, giving every node its
// per-layer matrix indices and every child its failure link. A child's
// base_val additionally absorbs the base_val of the node its link points to.
void build_automaton(Trie *root) {
    Trie *queue[MAX_N];
    int head = 0;
    int tail = 0;
    int order = 0;  // visiting order, used to derive the indices

    root->fail = NULL;
    queue[tail++] = root;
    while (head < tail) {
        Trie *cur = queue[head++];

        // Layer `layer` of this node lives at order + layer * node_cnt.
        for (int layer = 0; layer < LAYER; ++layer) {
            cur->index[layer] = order + node_cnt * layer;
        }
        ++order;

        for (int c = 0; c < SIZE; ++c) {
            Trie *kid = cur->childs[c];
            if (kid != NULL) {
                // Follow failure links until a node with a c-child shows up.
                Trie *link = cur->fail;
                while (link != NULL && link->childs[c] == NULL) {
                    link = link->fail;
                }
                kid->fail = (link != NULL) ? link->childs[c] : root;
                kid->base_val += kid->fail->base_val;
                queue[tail++] = kid;
            }
        }
    }
}

// Decide which automaton layer a value belongs to: values above 2 are
// capped at 2, everything else is returned unchanged.
int inLayer(int x) {
    return std::min(x, 2);
}

// Get the child's matrix index. now_ind / node_cnt is the layer of the
// current state; that layer's cal_val of `now` plus the child's base_val,
// passed through inLayer, selects which of the child's indices is returned.
int getChildIndex(Trie *now, Trie *child, int now_ind) {
    const int now_layer = now_ind / node_cnt;
    const int total = now->cal_val[now_layer] + child->base_val;
    return child->index[inLayer(total)];
}

// Compute the updated weight for a transition: cal_val of `now` in the
// layer given by now_ind / node_cnt, plus the child's base_val.
int updataCalVal(Trie *now, Trie *child, int now_ind) {
    const int now_layer = now_ind / node_cnt;
    const int carried = now->cal_val[now_layer];
    return carried + child->base_val;
}

// Fill the reachability matrix with a breadth-first walk over states.
// Starting from root at matrix index 0, every dequeued state tries all SIZE
// symbols, resolves the target node (through failure links when the child
// is missing), and increments CGMatrix.m[from][to]. Each target state is
// enqueued the first time it is seen.
void BFS(Trie *root) {
    Trie *node_q[MAX_N * 3];
    int state_q[MAX_N * 3];
    int seen[MAX_N * 3] = {0};
    int head = 0;
    int tail = 0;

    node_q[tail] = root;
    state_q[tail] = 0;
    ++tail;

    while (head < tail) {
        Trie *cur = node_q[head];
        const int cur_state = state_q[head];
        ++head;
        seen[cur_state] = 1;

        for (int c = 0; c < SIZE; ++c) {
            Trie *next = cur->childs[c];
            if (next == NULL) {
                // No direct child: look along the failure chain for a node
                // that has one; if none does, the walk falls back to root.
                Trie *link = cur->fail;
                while (link != NULL && link->childs[c] == NULL) {
                    link = link->fail;
                }
                next = (link != NULL) ? link->childs[c] : root;
            }

            // Both helpers read cur's cal_val before next's is overwritten;
            // next may be cur itself, so this order must be kept.
            const int next_state = getChildIndex(cur, next, cur_state);
            next->cal_val[next_state / node_cnt] = updataCalVal(cur, next, cur_state);
            ++CGMatrix.m[cur_state][next_state];

            if (seen[next_state] != 0) {
                continue;
            }
            seen[next_state] = 1;
            state_q[tail] = next_state;
            node_q[tail] = next;
            ++tail;
        }
    }
}

// Conversion helper: rewrites str in place, turning 'A', 'T', 'C', 'G' into
// '0', '1', '2', '3'. Any other character is left as it is.
void transStr(char *str) {
    for (char *cur = str; *cur != '\0'; ++cur) {
        const char base = *cur;
        if (base == 'A') {
            *cur = '0';
        } else if (base == 'T') {
            *cur = '1';
        } else if (base == 'C') {
            *cur = '2';
        } else if (base == 'G') {
            *cur = '3';
        }
    }
}

int main() {
    int n, L;
    char str[200];
    while (scanf("%d%d", &n, &L) != EOF) {
        Trie *root = new_node();
        node_cnt = 1;
        for (int i = 0 ; i < n ; ++i) {
            getchar();
            scanf("%s", str);
            transStr(str);
            insert(root, str);
        }
        build_automaton(root);
    
        // 设置矩阵大小
        mat_size = node_cnt * 3;
        init_unint_mat();

        BFS(root);
        CGMatrix = quick_mat_pow(CGMatrix, L);

        int ans = 0;
        for (int j = 2 * node_cnt ; j < mat_size ; ++j) {
            ans += CGMatrix.m[0][j];
            ans %= MOD;
        }
        printf("%d\n", ans);
        memset(CGMatrix.m, 0, sizeof(CGMatrix.m));
        clear(root);
    }
    return 0;
}
posted @ 2017-11-29 18:57  ojnQ  阅读(228)  评论(0)  编辑  收藏  举报