BJWC2011 禁忌
题解
多模式匹配首先建 AC 自动机,看到 \(len \le 10^9\) 想到矩阵乘法优化。
朴素 DP
关于分割的最大值,可以贪心,只要走到一个能匹配串的点立刻返回根继续匹配就行,一定能保证最优。
以最后的结果枚举算期望显然是 \(\text{alphaset} ^ {len}\) 的,显然不可取。由于期望线性,不妨算贡献。
设 \(f[i][j]\) 为长度为 \(i\) 的字符串,走到 AC 自动机上对应节点为 \(j\) 的概率。
转移就是在 AC 自动机上枚举字符集,如果转移到的点能匹配( Fail 链上有禁忌串),则返回根,设这个转移点 \((u, v)\),即 \(f[i][v] = \sum f[i - 1][u] \times \frac{1}{\text{alphaset}}\)。
\(ans = \sum_{i = 0}^{len - 1} f[i][u] \times \frac{1}{\text{alphaset}}\) 满足有边 \((u, rt)\) 的。
复杂度 \(O(75\text{len})\)
矩阵优化
AC 自动机上的节点数最多为 \(75\) 个,且其实贡献是一个相加形式,矩阵优化应该是可行的。
考虑边递推每一层的同时维护 \(ans\),即矩阵多加一列, 即构造矩阵 \([F_i, ans] \times A = [F_{i + 1}, ans]\)
-
对于一条边 \((u, v)\) 贡献:\(A[u][j] \Leftarrow + \frac{1}{\text{alphaset}}\)
-
特别地若这条边 \((u, rt)\),对答案有贡献:\(A[u][idx + 1] \Leftarrow + \frac{1}{\text{alphaset}}\)
-
注意 \(ans\) 本身要传递至下一层:\(A[idx + 1][idx + 1] \Leftarrow +1\)
时间复杂度 \(O(75 ^ 3log_2{\text{len}})\)
注意此题卡精度,所有地方包括 \(\frac{1}{\text{alphaset}}\) 都要开 long double
#include <iostream>
#include <cstdio>
#include <cstring>
using namespace std;
typedef long double LD;
const int L = 80, S = 20;
int n, len, m;
int tr[L][26], fail[L], q[L], idx;
bool e[L];
char s[S];
struct Mat{
LD w[L][L];
int n, m;
Mat operator * (const Mat &b) const {
Mat c; c.n = n, c.m = b.m;
for (int i = 0; i <= n; i++) {
for (int j = 0; j <= c.m; j++) {
c.w[i][j] = 0;
for (int k = 0; k <= m; k++)
c.w[i][j] += w[i][k] * b.w[k][j];
}
}
return c;
}
void print() {
puts("Matrix !");
for (int i = 0; i <= n; i++) {
for (int j = 0; j <= m; j++) printf("%.2Lf ", w[i][j]);
puts("");
}
}
} res, A;
void insert() {
int p = 0;
for (int i = 1; s[i]; i++) {
int ch = s[i] - 'a';
if (!tr[p][ch]) tr[p][ch] = ++idx;
p = tr[p][ch];
}
e[p] = true;
}
void build() {
int hh = 0, tt = -1;
for (int i = 0; i < m; i++)
if (tr[0][i]) q[++tt] = tr[0][i];
while (hh <= tt) {
int u = q[hh++];
for (int i = 0; i < m; i++) {
int v = tr[u][i];
if (v) {
fail[v] = tr[fail[u]][i];
if (e[fail[v]]) e[v] = true;
q[++tt] = v;
} else tr[u][i] = tr[fail[u]][i];
}
}
}
int main() {
scanf("%d%d%d", &n, &len, &m);
for (int i = 1; i <= n; i++) {
scanf("%s", s + 1);
insert();
}
build();
res.n = 0, res.m = A.n = A.m = idx + 1;
res.w[0][0] = 1; A.w[idx + 1][idx + 1] = 1;
for (int u = 0; u <= idx; u++) {
for (int i = 0; i < m; i++) {
int v = tr[u][i];
if (e[v]) {
A.w[u][idx + 1] += (LD)1 / m;
v = 0;
}
A.w[u][v] += (LD)1 / m;
}
}
while (len) {
if (len & 1) res = res * A;
A = A * A;
len >>= 1;
}
printf("%Lf\n", res.w[0][idx + 1]);
}