SP8093 JZPGYZ - Sevenk Love Oimaster
Description
给定 \(n\) 个模板串,以及 \(m\) 个查询串。
依次查询每一个查询串是多少个模板串的子串。
Solution
发现题解区好多用 SAM 的,这里提供一篇AC自动机做法的题解。
我们先对询问串建AC自动机,然后计算每一个文本串对每一个询问串的贡献。
那么这个贡献要如何计算呢?
考虑对于一个文本串 \(S\)。我们先枚举它的每一位,相当于是在枚举前缀。
从每一位跳 \(fail\) 指针时,每跳到一个节点,这个节点所代表的字符串都是 \(S\) 的一个后缀,也就是 \(S\) 的一个字串。
现在我们要求的就是 \(S\) 在 AC 自动机上跳到的每一个点到根的链上的点集并(多读几遍吧……)。
容易想到树上差分来维护,即在 \(x\) 点,\(y\) 点上 +1,\(lca(x, y)\) 上 -1。
但暴力是 \(n^2\) 的,考虑优化。我们可以把 \(fail\) 树的 dfs 序计算出来,然后就变成了区间加以及区间求和。
所以直接用线段树或树状数组维护一下即可。
我用的树状数组。
Code
#include <iostream>
#include <cstdio>
#include <cstring>
#include <algorithm>
#include <queue>
using namespace std;
inline int read(){
int x = 0; char ch = getchar();
while(ch < '0' || ch > '9') ch = getchar();
while(ch >= '0' && ch <= '9') x = (x << 3) + (x << 1) + ch - '0', ch = getchar();
return x;
}
const int N = 1e5 + 10;
const int M = 4e5 + 5;
int n, m;
int trie[M][27], tot = 1, End[M];
char s[M], tmp[10010][N];
int fail[M];
inline void insert(char s[], int id){
int len = strlen(s + 1), now = 1;
for(int i = 1; i <= len; i++){
int c = s[i] - 'a';
if(!trie[now][c]) trie[now][c] = ++tot;
now = trie[now][c];
}
End[id] = now;
}
inline void build(){
queue <int> q;
for(int i = 0; i < 26; i++)
trie[0][i] = 1;
q.push(1);
fail[1] = 0;
while(!q.empty()){
int now = q.front();
q.pop();
for(int i = 0; i < 26; i++){
if(trie[now][i]) fail[trie[now][i]] = trie[fail[now]][i], q.push(trie[now][i]);
else trie[now][i] = trie[fail[now]][i];
}
}
}
struct node{
int v, nxt;
}edge[M];
int head[M], num;
int siz[M], son[M], dep[M];
int top[M], dfn[M], cnt;
int seq[M];
inline bool cmp(int a, int b){
return dfn[a] < dfn[b];
}
inline void add_edge(int x, int y){
edge[++num] = (node){y, head[x]};
head[x] = num;
}
inline void dfs1(int x){
dep[x] = dep[fail[x]] + 1, siz[x] = 1;
for(int i = head[x]; i; i = edge[i].nxt){
int y = edge[i].v;
dfs1(y);
siz[x] += siz[y];
if(!son[x] || siz[y] > siz[son[x]])
son[x] = y;
}
}
inline void dfs2(int x, int topfa){
top[x] = topfa, dfn[x] = ++cnt;
if(!son[x]) return;
dfs2(son[x], topfa);
for(int i = head[x]; i; i = edge[i].nxt){
int y = edge[i].v;
if(y == son[x]) continue;
dfs2(y, y);
}
}
inline int lca(int x, int y){
while(top[x] != top[y]){
if(dep[top[x]] < dep[top[y]]) swap(x, y);
x = fail[top[x]];
}
return dep[x] < dep[y] ? x : y;
}
struct BIT{
int c[M];
void add(int x, int y){
for(; x <= tot; x += x & (-x)) c[x] += y;
}
int query(int x){
int res = 0;
for(; x > 0; x -= x & (-x)) res += c[x];
return res;
}
}t;
inline void update(char s[]){
int len = strlen(s + 1);
int now = 1;
for(int i = 1; i <= len; i++){
now = trie[now][s[i] - 'a'];
seq[i] = now;
}
sort(seq + 1, seq + 1 + len, cmp);
for(int i = 1; i <= len; i++) t.add(dfn[seq[i]], 1);
for(int i = 1; i < len; i++) t.add(dfn[lca(seq[i], seq[i + 1])], -1);
}
int main(){
n = read(), m = read();
for(int i = 1; i <= n; i++)
scanf("%s", tmp[i] + 1);
for(int i = 1; i <= m; i++){
scanf("%s", s + 1);
insert(s, i);
}
build();
for(int i = 2; i <= tot; i++)
add_edge(fail[i], i);
dfs1(1), dfs2(1, 1);
for(int i = 1; i <= n; i++)
update(tmp[i]);
for(int i = 1; i <= m; i++){
int x = End[i];
printf("%d\n", t.query(dfn[x] + siz[x] - 1) - t.query(dfn[x] - 1));
}
return 0;
}