【BZOJ #3881】「COCI2015」Divljak
Description
有 \(n\) 个字符串 \(S_1, S_2, \cdots, S_n\)。以及一个字典 \(T\),一开始字典 \(T\) 为空。
接下来有 \(q\) 个操作,操作包含以下两种:
1 P
:向字典 \(T\) 中插入一个字符串 \(P\)。2 x
:请你求出 \(S_x\) 是字典 \(T\) 中多少个串的子串。
数据范围:\(1 \leq n, q \leq 10^5\),字符集为小写字母集,字符串总长 \(\leq 2 \times 10^6\)。
时空限制:\(4000 \ \text{ms} / 500 \ \text{MiB}\)。
Solution
操作 2 是一个多模匹配问题。
考虑将 \(S_1, S_2, \cdots, S_n\) 建出 AC 自动机。
对 AC 自动机上的每一个点 \(u\),求出它的失配指针 \(fail_u\),将 \(fail_u \to u\) 连边,即可得到一棵 \(fail\) 树。
对于字典 \(T\) 中新插入的一个字符串 \(P\),考虑求 \(P\) 对 \(S_1, S_2, \cdots, S_n\) 中的哪些字符串造成贡献。
考虑文本串 \(P\) 匹配的过程:\(P\) 在 AC 自动机上一个字符一个字符走的过程,相当于枚举了一个前缀,任意时刻在 AC 自动机(trie 图)上走到的节点 \(u\) 代表的字符串,即为该前缀与自动机匹配的最长后缀。那么,我们考虑在 \(u\) 节点这个位置向上跳 \(fail\) 指针,根据 \(fail\) 指针的定义,路径上经过的每一个节点代表的字符串都是 \(P\) 的子串。
设 \(P\) 在 AC 自动机上依次走到了节点 \(u_1, u_2, \cdots, u_k\)。
那么 \(P\) 在 AC 自动机上能匹配到的子串位于 \(u_1, u_2, \cdots, u_k\) 在 \(fail\) 树上到根节点上的链的点集并。
那么现在要做的是将该点集内的所有点的答案加上 \(1\),本质上是一个树链求并。
注意到根节点是固定的,可以考虑将 \(u_1, u_2, \cdots, u_k\) 按照在 \(fail\) 树中的 dfs 序排序后,做下面的事情:
- 对于每个 \(1 \leq i \leq k\),将 \(u_i\) 在 \(fail\) 树上到根节点上的链的所有点的答案加上 \(1\)。
- 对于每个 \(1 \leq i < k\),将 \(\text{lca}(u_i, u_{i + 1})\) 在 \(fail\) 树上到根节点上的链的所有点的答案减去 \(-1\)。
现在问题转化为了:" 路径加 " & " 单点求值 "。
可以使用树上差分将问题转化为:" 单点加 " & " 子树求和 "。
在 dfs 序上维护一个树状数组即可实现上述操作。
#include <cstdio>
#include <cstring>
#include <algorithm>
#include <queue>
using namespace std;
const int N = 100100, SIZE = 2001000;
int n, m;
char S[N];
int cT = 1;
struct AC {
int trans[26];
int fail;
} t[SIZE];
int End[N];
void insert(char *S, int id) {
int p = 1, len = strlen(S + 1);
for (int i = 1; i <= len; i ++) {
int v = S[i] - 'a';
if (!t[p].trans[v]) t[p].trans[v] = ++ cT;
p = t[p].trans[v];
}
End[id] = p;
}
void GetFail() {
for (int i = 0; i < 26; i ++)
t[0].trans[i] = 1;
queue<int> q;
q.push(1), t[1].fail = 0;
while (q.size()) {
int u = q.front(); q.pop();
for (int i = 0; i < 26; i ++) {
if (t[u].trans[i])
t[t[u].trans[i]].fail = t[t[u].fail].trans[i], q.push(t[u].trans[i]);
else
t[u].trans[i] = t[t[u].fail].trans[i];
}
}
}
int tot, head[SIZE], ver[SIZE], Next[SIZE];
void addedge(int u, int v) {
ver[++ tot] = v; Next[tot] = head[u]; head[u] = tot;
}
int d[SIZE];
int size[SIZE];
int son[SIZE];
void dfs1(int u) {
size[u] = 1;
for (int i = head[u]; i; i = Next[i]) {
int v = ver[i];
d[v] = d[u] + 1;
dfs1(v);
size[u] += size[v];
if (size[v] > size[son[u]]) son[u] = v;
}
}
int ovo, dfn[SIZE];
int top[SIZE];
void dfs2(int u) {
dfn[u] = ++ ovo;
if (son[u]) {
top[son[u]] = top[u];
dfs2(son[u]);
}
for (int i = head[u]; i; i = Next[i]) {
int v = ver[i];
if (v == son[u]) continue;
top[v] = v;
dfs2(v);
}
}
int lca(int x, int y) {
while (top[x] != top[y]) {
if (d[top[x]] > d[top[y]]) swap(x, y);
y = t[top[y]].fail;
}
if (d[x] > d[y]) swap(x, y);
return x;
}
int c[SIZE];
void add(int x, int val) {
for (; x <= cT; x += x & -x) c[x] += val;
}
int ask(int x) {
int ans = 0;
for (; x; x -= x & -x) ans += c[x];
return ans;
}
int seq[SIZE];
bool cmp(int i, int j) {
return dfn[i] < dfn[j];
}
int main() {
scanf("%d", &n);
for (int i = 1; i <= n; i ++) {
scanf("%s", S + 1);
insert(S, i);
}
scanf("%d", &m);
GetFail();
for (int i = 2; i <= cT; i ++)
addedge(t[i].fail, i);
d[1] = 1, dfs1(1);
top[1] = 1, dfs2(1);
while (m --) {
int opt, x;
scanf("%d", &opt);
switch (opt) {
case 1: {
scanf("%s", S + 1);
int p = 1, len = strlen(S + 1);
for (int i = 1; i <= len; i ++) {
int v = S[i] - 'a';
p = t[p].trans[v];
seq[i] = p;
}
sort(seq + 1, seq + 1 + len, cmp);
for (int i = 1; i <= len; i ++) {
int p = seq[i];
add(dfn[p], 1);
}
for (int i = 1; i < len; i ++) {
int p = seq[i], q = seq[i + 1];
add(dfn[lca(p, q)], -1);
}
break;
}
case 2: {
scanf("%d", &x);
int p = End[x];
printf("%d\n", ask(dfn[p] + size[p] - 1) - ask(dfn[p] - 1));
break;
}
}
}
return 0;
}