Codeforces 235C Cyclical Quest (后缀自动机)
思路:一眼看过去,好像处理出每个字串的最小表示的 \(hash\) 值就可以解决了, 但想了复杂度明显过不去,由于要统计某种子串个数,所以首先想到后缀自动机,然后分析,我们将每次查询的模式串翻倍(接在自身后面),模式串的原本长度为 \(n\) ,假设我们现在在后缀自动机上找到了区间 \((le, ri)\) 的子串,首先判断 \(ri - le + 1\) 是否等于 \(n\) , 若相等则加上该节点 \(ednpoints\) 集合大小,然后我们要查询的就是 $(le + 1, ri + 1) $ 的子串了,首先看子串 \((le + 1, ri)\) 是否属于该节点,若不属于,则沿着 \(link\) 链接向上跳,跳到包含子串 \((le + 1, ri)\) 的节点 \(p\) ,然后判断 \(st[p].next[s[ri + 1]]\) 是否存在,若存在,则 \(p\) 跳到 \(p = st[p].next[s[ri + 1]]\) ,否则 \(p\) 直接跳到 \(st[p].link\) , 并更新对应的 \(le\) 。具体看代码
#include <bits/stdc++.h>
using namespace std;
typedef long long LL;
const int maxn = 1e5 + 50;
struct state {
int len, link;
int next[26];
};
state st[maxn * 20];
int sz, last;
void sam_init() {
st[0].len = 0;
st[0].link = -1;
sz = 1;
last = 0;
}
LL num[maxn * 20];
void sam_extend(int c) {
int cur = sz++;
st[cur].len = st[last].len + 1;
int p = last;
while (p != -1 && !st[p].next[c]) {
st[p].next[c] = cur;
p = st[p].link;
}
if (p == -1) {
st[cur].link = 0;
} else {
int q = st[p].next[c];
if (st[p].len + 1 == st[q].len) {
st[cur].link = q;
} else {
int clone = sz++;
st[clone].len = st[p].len + 1;
for(int i = 0; i < 26; i++) st[clone].next[i] = st[q].next[i];
st[clone].link = st[q].link;
while (p != -1 && st[p].next[c] == q) {
st[p].next[c] = clone;
p = st[p].link;
}
st[q].link = st[cur].link = clone;
}
}
last = cur;
}
struct Edge
{
int to, next;
} edge[maxn * 40];
int k, head[maxn * 20];
void add(int a, int b){
edge[k].to = b;
edge[k].next = head[a];
head[a] = k++;
}
void dfs(int u, int pre){
for(int i = head[u]; i != -1; i = edge[i].next){
int to = edge[i].to;
if(to == pre) continue;
dfs(to, u);
num[u] += num[to];
}
}
string s, t;
int vis[maxn * 20];
int main(int argc, char const *argv[])
{
cin >> t;
int tlen = t.size();
sam_init();
for(int i = 0; i < tlen; i++){
sam_extend(t[i] - 'a');
num[last] = 1;
}
for(int i = 0; i < sz; i++) head[i] = -1;
for(int i = 1; i < sz; i++){
add(i, st[i].link);
add(st[i].link, i);
}
dfs(0, -1);
int q;
scanf("%d", &q);
int id = 0;
while(q--){
id++;
cin >> s;
int n = s.size();
s += s;
int p = 0;
int le = 0, ri = 0;
LL ans = 0;
while(le < n && ri < 2 * n){
if(st[p].next[s[ri] - 'a']){
p = st[p].next[s[ri] - 'a'];
if(ri - le + 1 == n){
if(vis[p] != id){ // 记录一下该点的贡献已经加过,防止重复算贡献,比如第二个样例
vis[p] = id;
ans += num[p];
}
le++;
while(st[st[p].link].len + 1 > n - 1 && p != 0){
p = st[p].link;
}
}
ri++;
} else {
if(p == 0) le++, ri = le; // 注意,若 p 是节点 0 ,则需要让 le++, 否则会死循环
p = st[p].link;
le = ri - 1 - st[p].len + 1;
}
}
printf("%lld\n", ans);
}
return 0;
}