POJ 2778 DNA Sequence (AC自动机+矩阵快速幂)
Description
给出\(m\)个模式串,问长度为\(n\)的不包含任意模式串的文本串有多少个?
Input
第一行两个整数\(m\)和\(n\)(\(1 \leqslant m \leqslant 10\),\(1 \leqslant n \leqslant 2 \times 10^9\)),表示模式串的数量和要构造的文本串的长度。
接下来\(m\)行每行一个模式串。每个模式串的长度不会超过\(10\)。
Output
一个整数,表示满足条件的模式串的数量。结果模\(100000\)。
Sample Input
4 3
AT
AC
AG
AA
Sample Output
36
Solution
AC自动机题目,要构造长度为\(n\)的文本串,相当于在自动机上从起点走\(n\)步,要求不包含任意模式串,相当于不能经过任意一个自动机上模式串的终点。因为自动机也可以理解为一张图,故可以写出的AC自动机的邻接矩阵\(M\),其中\(M_{i,j}\)表示从点\(i\)一步到点\(j\)的边的数量,除去一步走到模式串终点的边。那么\(M^n\)就记录不经过一步到模式串终点的边的情况下走n步任意两点之间边的数量。累加从起点到每个点的走法的数量就是答案。
此题时限为1000ms比较紧张,发现了取模运算对整体运行时间的决定性的影响,若在矩阵乘法函数的三重for循环内的A[i][k] * B[k][j]后面加一次取模,整体运行时间增加接近一倍,甚者导致超时。
Code
#include <cstdio>
#include <cstring>
#include <algorithm>
#include <queue>
using namespace std;
typedef long long ll;
const int INF = 0x3f3f3f3f;
const ll mod = 100000;
const int N = 105;
const int M = 105;
const int Z = 4;
struct Matrix
{
ll a[N][N], n;
Matrix() {}
Matrix(ll n) : n(n)
{
memset(a, 0, sizeof(a));
}
ll* operator[](int i)
{
return a[i];
}
};
Matrix operator*(Matrix A, Matrix B)
{
int n = A.n;
Matrix C(n);
for (int i = 1; i <= n; i++)
for (int j = 1; j <= n; j++)
for (int k = 1; k <= n; k++)
C[i][j] = (C[i][j] + A[i][k] * B[k][j]) % mod;
return C;
}
Matrix power(Matrix A, ll n, ll mod)
{
Matrix C(A.n);
for (int i = 1; i <= A.n; i++) C[i][i] = 1;
while (n)
{
if (n & 1) C = C * A;
A = A * A;
n >>= 1;
}
return C;
}
int f(char c)
{
switch (c)
{
case 'A': return 0;
case 'C': return 1;
case 'G': return 2;
case 'T': return 3;
}
}
int trie[M][Z], tot, fail[M];
bool ed[M];
int newnode()
{
memset(trie[tot], -1, sizeof(trie[tot]));
fail[tot] = -1;
ed[tot] = false;
return tot++;
}
void init()
{
tot = 0;
newnode();
}
void insert(char s[])
{
int len = strlen(s);
int p = 0;
for (int i = 0; i < len; i++)
{
int c = f(s[i]);
if (trie[p][c] == -1) trie[p][c] = newnode();
p = trie[p][c];
}
ed[p] = true;
}
queue<int> q;
void build()
{
while (!q.empty()) q.pop();
fail[0] = 0;
for (int i = 0; i < Z; i++)
{
if (trie[0][i] == -1) trie[0][i] = 0;
else fail[trie[0][i]] = 0, q.push(trie[0][i]);
}
while (!q.empty())
{
int p = q.front(); q.pop();
if (ed[fail[p]]) ed[p] = true;
for (int i = 0; i < Z; i++)
{
if (trie[p][i] == -1) trie[p][i] = trie[fail[p]][i];
else fail[trie[p][i]] = trie[fail[p]][i], q.push(trie[p][i]);
}
}
}
Matrix get_Matrix()
{
Matrix M(tot);
for (int i = 0; i < tot; i++)
for (int j = 0; j < Z; j++)
if (!ed[trie[i][j]])
M[i + 1][trie[i][j] + 1]++;
return M;
}
char s[M];
int main()
{
int m, n;
scanf("%d%d", &m, &n);
init();
while (m--) scanf("%s", s), insert(s);
build();
Matrix Mat = get_Matrix();
Mat = power(Mat, n, mod);
ll ans = 0;
for (int i = 1; i <= tot; i++) ans = (ans + Mat[1][i]) % mod;
printf("%lld\n", ans);
return 0;
}