[AC自动机 fail树] [NOI2011]阿狸的打字机
题目
题面
大意就是给定若干字符串(输入方式需要用\(trie\)处理) 然后多组询问求第\(x\)个打印的串在第\(y\)个打印的串出现几次。
题解
先建出\(AC\)自动机
考虑第\(x\)个打印的串在第\(y\)个打印的串出现几次,即等价于有多少属于\(y\)的字符串的\(fail\)指针指向\(x\)的尾结点。
多对一不妨变为一对多的考虑,所以建立\(fail\)树。
那么就便成了\(fail\)树上\(x\)的尾结点指向多少属于\(y\)的字符串。
即\(fail\)树上\(x\)的子树中有多少个属于\(y\)的字符串。
那么现在问题就转化成了 给定一棵树,多组询问,求\(x\)子树下第\(y\)种颜色出现的次数(但是一个点可能会有多种颜色 坑!)
所以我们离线处理,在\(trie\)上走,走到\(y\)字符串结束时 计算每个\(x\)子树中的贡献。
#include <cstdio>
#include <cstring>
#include <cmath>
#include <algorithm>
#include <iostream>
#include <set>
#include <map>
#include <queue>
using namespace std;
template <typename T> void chkmax(T &x, T y) {x = x > y ? x : y;}
template <typename T> void chkmin(T &x, T y) {x = x > y ? y : x;}
typedef long long ll;
const int INF = 2139062143;
#define DEBUG(x) std::cerr << #x << " = " << x << std::endl
template <typename T> void read (T &x) {
x = 0; bool f = 1; char ch;
do {ch = getchar(); if (ch == '-') f = 0;} while (ch > '9' || ch < '0');
do {x = x * 10 + ch - '0'; ch = getchar();} while (ch >= '0' && ch <= '9');
x = f ? x : -x;
}
template <typename T> void write (T x) {
if (x < 0) x = ~x + 1, putchar ('-');
if (x > 9) write (x / 10);
putchar (x % 10 + '0');
}
const int N = 5e5 + 7;
int n, cnt, l, num, E, dfs_clock, in[N], fa[N], ans[N], out[N], dfn[N], head[N], fail[N], word[N], cpy[N][26], trie[N][26];
char s[N], p[N];
string ch[N];
vector < int > v[N];
struct EDGE {
int to, nxt;
} edge[N << 1];
struct Node {
int x, y;
} a[N];
struct BIT {
int sz, c[N];
inline int lowbit(int x) {
return x & -x;
}
inline void add(int i, int x) {
for(; i <= sz; i += lowbit(i)) c[i] += x;
}
inline int sum(int i) {
int ret = 0;
for(; i; i -= lowbit(i)) ret += c[i];
return ret;
}
inline int query(int l, int r) {
return sum(r) - sum(l - 1);
}
} bit;
inline void addedge(int u, int v) {
edge[++E].to = v;
edge[E].nxt = head[u];
head[u] = E;
}
inline void insert(char *s) {
int len = strlen(s), now = 0;
for(int i = 0; i < len; i++) {
if(s[i] == 'B') {now = fa[now]; continue;}
if(s[i] == 'P') {word[++num] = now; continue;}
int v = s[i] - 'a';
if(!trie[now][v]) trie[now][v] = ++ cnt;
fa[trie[now][v]] = now;
now = trie[now][v];
}
memcpy(cpy, trie, sizeof(trie));
}
inline void build() {
queue < int > q;
for(int i = 0; i < 26; i++) {
if(trie[0][i]) {
q.push(trie[0][i]);
fail[trie[0][i]] = 0;
}
}
while(!q.empty()) {
int now = q.front(); q.pop();
for (int i = 0; i < 26; i++) {
if(trie[now][i]) {
q.push(trie[now][i]);
fail[trie[now][i]] = trie[fail[now]][i];
} else {
trie[now][i] = trie[fail[now]][i];
}
}
}
for(int i = 1; i <= cnt; i++) addedge(fail[i], i);
}
inline void dfs(int u) {
in[u] = ++ dfs_clock;
for(int i = head[u]; i; i = edge[i].nxt) dfs(edge[i].to);
out[u] = dfs_clock;
}
inline void solve(int now) {
bit.add(in[now], 1);
for(int i = 0; i < v[now].size(); i++) ans[v[now][i]] = bit.query(in[word[a[v[now][i]].x]], out[word[a[v[now][i]].x]]);
for(int i = 0; i < 26; i++) if(cpy[now][i]) solve(cpy[now][i]);
bit.add(in[now], -1);
}
int main() {
scanf("%s", s); insert(s);
build(); dfs(0); read(n);
bit.sz = dfs_clock;
for(int i = 1; i <= n; i++) {
read(a[i].x); read(a[i].y);
v[word[a[i].y]].push_back(i);
}
solve(0);
for(int i = 1; i <= n; i++) printf("%d\n", ans[i]);
return 0;
}