自动机
自动机
定义
- 自动机是一种通过 状态 之间的 跳转 进行计算的数学模型。
- 当自动机接受一个输入字符时,它使用 状态转移函数,依据 当前所处的状态 和 输入的字符 跳转至下一个状态。
- 我们通常使用有向图表示一个有限状态自动机。此时,状态在有向图上以结点形式表示;状态转移函数表示为这张图上的有向边的集合。
示例:判断输入的二进制数奇偶性的自动机
-
这个自动机仅有两种状态:\(q_0\) 和 \(q_1\)。起始状态是 \(q_0\)。
-
仅接受两种字符 \(0, 1\)。
-
状态转移函数:
\[\delta (q_0, 0) = q_0 \]\[\delta(q_0, 1) = q_1 \]\[\delta(q_1, 0) = q_0 \]\[\delta(q_1, 1) = q_1 \] -
从起始状态 \(q_0\) 出发,将一个二进制数从高位至低位输入自动机;如果抵达状态 \(q_0\),则数字是偶数;否则是奇数。
KMP 与前缀自动机
我们可以快速回顾一下 KMP。
首先,我们定义前缀函数 \(p_i\) 表示子串 \(s_1 \dots s_i\) 的最长 border 长度,这个在 KMP 中也有用到。
我们记模式串 \(t\) 的长度为 \(n\),构造一个包含 \(q_0, q_1, \dots, q_n\) 共计 \(n + 1\) 种状态的自动机,状态 \(q_i\) 对应匹配到 \(t\) 的第 \(i\) 个字符。
我们也可以一边算状态转移函数,一边算前缀函数:\(p_{i + 1} = \delta(p_i, t_{i + 1})\)。
假设 \(t = a \ b \ a \ a \ b\)
CF1721E
题意
给定一个字符串 \(s\),有 \(q\) 次查询,每次查询给定一个长度不超过 \(10\) 的字符串 \(t\),执行以下操作:
- 连接 \(s\) 和 \(t\)。
- 计算字符串 \(s + t\) 的前缀函数。
- 输出前缀函数在 \(\vert s \vert + 1, \vert s \vert + 2, \dots, \vert s \vert + \vert t \vert\) 的值。
- 将字符串恢复成 \(s\)。
思路
首先,我们可以想到一种暴力的思路,那就是每次都把字符串连起来,暴力的用 KMP 的方法算出 \(p_i \ (\vert s \vert + 1 \le i \le \vert s \vert + \vert t \vert)\)
但是,KMP 只保证总时间复杂度的正确性,并不保证计算单个前缀函数的复杂度。所以,我们可以考虑用前缀自动机算出前缀函数,可以保证复杂度。
代码
#include <bits/stdc++.h>
using namespace std;
const int N = 1e6 + 50;
string s, t;
int q, n, m, p[N], nxt[N][26];
void Solve() {
cin >> t;
m = t.size(), t = ' ' + t;
for (int i = n; i <= n + m; i++) {
for (int j = 0; j < 26; j++) {
if (i < n + m && j == t[i + 1 - n] - 'a') {
nxt[i][j] = i + 1;
p[i + 1] = nxt[p[i]][j];
} else nxt[i][j] = nxt[p[i]][j];
}
}
for (int i = 1; i <= m; i++) cout << p[n + i] << ' ';
cout << '\n';
}
int main() {
ios::sync_with_stdio(0), cin.tie(0);
cin >> s >> q;
n = s.size(), s = ' ' + s;
nxt[0][s[1] - 'a'] = 1;
for (int i = 1; i <= n; i++) {
for (int j = 0; j < 26; j++) {
if (i < n && j == s[i + 1] - 'a') {
nxt[i][j] = i + 1;
p[i + 1] = nxt[p[i]][j];
} else nxt[i][j] = nxt[p[i]][j];
}
}
while (q--) Solve();
return 0;
}
子序列自动机
状态
给定文本串 \(s\) 和模式串 \(t\),判断 \(s\) 是否是 \(t\) 的一个子序列。
我们考虑对 \(t\) 构造一个包含 \(n + 2\) 个状态的自动机:\(q_0, q_1, \dots, q_n, q_{-1}\)。
当对自动机输入串 \(s\) 后,停留在状态 \(q_i \ (1 \le i \le n)\) 时,表示字符串 \(s\) 是 \(t_1, t_2, \dots, t_i\) 的一个子序列。
停留在状态 \(q_{-1}\) 表示字符串 \(s\) 不是 \(t\) 的一个子序列。
转移
\(\delta(q_i, c)\) 是字符串 \(t\) 中从第 \(i\) 个字符往后的字符 \(c\) 第一次出现的位置。
如果不存在 \(c\),那么 \(\delta(q_i, c)\) 为 \(q_{-1}\)。
假设 \(t = a \ b \ a \ a\ b\),那么子序列自动机就是这样的:
洛谷 P1819
题意
求出 \(3\) 个字符串有多少个不同的公共子序列,不包括空序列。
思路
我们先对这 \(3\) 个字符串分别建出三个子序列自动机。
设 \(dp_{i, j, k}\) 为从第一个字符串的 \(q_{1, i}\) 状态,第二个字符串的 \(q_{2, j}\) 状态,第三个字符串的 \(q_{3, k}\) 状态出发的公共子序列数量。
然后枚举下一个选择哪一个字符,对于每一个 \(c\),将 \(dp_{i, j, k}\) 更新为 \(dp_{i, j, k} + dp_{\delta(q_{1,i}, c), \delta(q_{2, j}, c), \delta(q_{3, k}, c)}\) 即可。
代码
#include <bits/stdc++.h>
using namespace std;
using ll = long long;
const int N = 160, mod = 1e8;
int n, fs[N][26], ft[N][26], fstr[N][26], last[26];
ll dp[N][N][N];
string s, t, str;
ll dfs(int i, int j, int k) {
if (i > n || j > n || k > n) return 0;
if (dp[i][j][k]) return dp[i][j][k];
for (int p = 0; p < 26; p++) {
dp[i][j][k] = (dp[i][j][k] + dfs(fs[i][p], ft[j][p], fstr[k][p])) % mod;
}
dp[i][j][k] = (dp[i][j][k] + 1) % mod;
return dp[i][j][k];
}
int main() {
ios::sync_with_stdio(0), cin.tie(0);
cin >> n >> s >> t >> str;
s = ' ' + s, t = ' ' + t, str = ' ' + str;
fill(last, last + 26, n + 1);
for (int i = n; i >= 0; i--) {
for (int j = 0; j < 26; j++) fs[i][j] = last[j];
if (i) last[s[i] - 'a'] = i;
}
fill(last, last + 26, n + 1);
for (int i = n; i >= 0; i--) {
for (int j = 0; j < 26; j++) ft[i][j] = last[j];
if (i) last[t[i] - 'a'] = i;
}
fill(last, last + 26, n + 1);
for (int i = n; i >= 0; i--) {
for (int j = 0; j < 26; j++) fstr[i][j] = last[j];
if (i) last[str[i] - 'a'] = i;
}
cout << (dfs(0, 0, 0) + mod - 1) % mod;
return 0;
}
abc299 f
题意
给定一个字符串 \(s\),你需要算出存在多少种字符串 \(t\),使得 \(tt\) 是 \(s\) 的子序列。
思路
首先,我们会发现,当某个字符串 \(tt\) 是 \(s\) 的子序列时,我们是可以将 \(s\) 分为两个部分,使得这两个部分分别包含一个 \(t\) 的。
那么,我们先考虑枚举分割点,也就是说枚举每一个 \(p \ (1 \le p < |s|)\),然后求 \([1, p], [p + 1, |s|]\) 的公共子序列数量。
主要是需要注意不能计算相同的 \(t\),所以我们强制让第一个出现的 \(t\) 的右端点为 \(p\) 即可。
代码
#include <bits/stdc++.h>
using namespace std;
using ll = long long;
const int N = 110, mod = 998244353;
string s;
int n, fs[N][26], ft[N][26], last[26];
ll dp[N][N], ans;
int main() {
ios::sync_with_stdio(0), cin.tie(0);
cin >> s, n = s.size(), s = ' ' + s;
for (int p = 1; p <= n; p++) {
fill(last, last + 26, n + 1);
for (int i = p; i >= 0; i--) {
for (int j = 0; j < 26; j++) {
fs[i][j] = last[j];
}
last[s[i] - 'a'] = i;
}
fill(last, last + 26, n + 1);
for (int i = n; i >= p; i--) {
for (int j = 0; j < 26; j++) {
ft[i][j] = last[j];
}
last[s[i] - 'a'] = i;
}
for (int i = p; i >= 0; i--) {
for (int j = n; j >= p; j--) {
if (s[i] != s[j] && (i || j != p)) continue;
if (i == p) (dp[i][j] += 1) %= mod;
for (int k = 0; k < 26; k++) {
if (fs[i][k] <= n && ft[j][k] <= n) {
(dp[i][j] += dp[fs[i][k]][ft[j][k]]) %= mod;
}
}
}
}
(ans += dp[0][p]) %= mod;
for (int i = p; i >= 0; i--) {
for (int j = n; j >= p; j--) {
dp[i][j] = 0;
}
}
}
cout << ans;
return 0;
}